diff --git a/.pipelines/mssql-pipelines.yml b/.pipelines/mssql-pipelines.yml index bdfb5e37e6..d9a2201695 100644 --- a/.pipelines/mssql-pipelines.yml +++ b/.pipelines/mssql-pipelines.yml @@ -177,7 +177,7 @@ jobs: # for the linux job above. data-source.connection-string: Server=(localdb)\MSSQLLocalDB;Persist Security Info=False;Integrated Security=True;MultipleActiveResultSets=False;Connection Timeout=30;TrustServerCertificate=True; InstallerUrl: https://download.microsoft.com/download/7/c/1/7c14e92e-bdcb-4f89-b7cf-93543e7112d1/SqlLocalDB.msi - SqlVersionCode: '15.0' + SqlVersionCode: '17.0' steps: - template: templates/mssql-test-steps.yml @@ -200,7 +200,7 @@ jobs: # for the linux job above. data-source.connection-string: Server=(localdb)\MSSQLLocalDB;Persist Security Info=False;Integrated Security=True;MultipleActiveResultSets=False;Connection Timeout=30;TrustServerCertificate=True; InstallerUrl: https://download.microsoft.com/download/7/c/1/7c14e92e-bdcb-4f89-b7cf-93543e7112d1/SqlLocalDB.msi - SqlVersionCode: '15.0' + SqlVersionCode: '17.0' steps: - template: templates/mssql-test-steps.yml diff --git a/config-generators/mssql-commands.txt b/config-generators/mssql-commands.txt index e3f2541983..99e138b846 100644 --- a/config-generators/mssql-commands.txt +++ b/config-generators/mssql-commands.txt @@ -16,6 +16,8 @@ add Broker --config "dab-config.MsSql.json" --source brokers --permissions "anon add WebsiteUser --config "dab-config.MsSql.json" --source website_users --permissions "anonymous:create,read,delete,update" add WebsiteUser_MM --config "dab-config.MsSql.json" --source website_users_mm --graphql "websiteuser_mm:websiteusers_mm" --permissions "anonymous:*" add SupportedType --config "dab-config.MsSql.json" --source type_table --permissions "anonymous:create,read,delete,update" +add VectorType --config "dab-config.MsSql.json" --source vector_type_table --rest true --graphql false --permissions "anonymous:create,read,delete,update" +update VectorType --config "dab-config.MsSql.json" --permissions "authenticated:create,read,delete,update" add stocks_price --config "dab-config.MsSql.json" --source stocks_price --permissions "authenticated:create,read,update,delete" update stocks_price --config "dab-config.MsSql.json" --permissions "anonymous:read" update stocks_price --config "dab-config.MsSql.json" --permissions "TestNestedFilterFieldIsNull_ColumnForbidden:read" --fields.exclude "price" diff --git a/scripts/start-mssql-server.bash b/scripts/start-mssql-server.bash index 5268a5b807..4b1f0a0a42 100644 --- a/scripts/start-mssql-server.bash +++ b/scripts/start-mssql-server.bash @@ -27,7 +27,7 @@ echo "forceencryption = 1" >> $CERT_DIR/mssql.conf cat $CERT_DIR/mssql.conf # Start mssql-server by volume mounting the cert, key and conf files. -docker run -e "ACCEPT_EULA=Y" -e "MSSQL_SA_PASSWORD=$DOCKER_SQL_PASS" -p 1433:1433 --name customerdb -h customerdb -v $CERT_DIR/mssql.conf:/var/opt/mssql/mssql.conf -v $CERT_DIR/mssql.pem:/var/opt/mssql/mssql.pem -v $CERT_DIR/mssql.key:/var/opt/mssql/mssql.key -d mcr.microsoft.com/mssql/server:2019-latest +docker run -e "ACCEPT_EULA=Y" -e "MSSQL_SA_PASSWORD=$DOCKER_SQL_PASS" -p 1433:1433 --name customerdb -h customerdb -v $CERT_DIR/mssql.conf:/var/opt/mssql/mssql.conf -v $CERT_DIR/mssql.pem:/var/opt/mssql/mssql.pem -v $CERT_DIR/mssql.key:/var/opt/mssql/mssql.key -d mcr.microsoft.com/mssql/server:2025-latest sleep 30 docker logs customerdb diff --git a/src/Core/Models/SqlTypeConstants.cs b/src/Core/Models/SqlTypeConstants.cs index 6315e88936..c2d970e439 100644 --- a/src/Core/Models/SqlTypeConstants.cs +++ b/src/Core/Models/SqlTypeConstants.cs @@ -48,6 +48,7 @@ public static class SqlTypeConstants { "datetime2", true }, // SqlDbType.DateTime2 { "datetimeoffset", true }, // SqlDbType.DateTimeOffset { "", false }, // SqlDbType.Udt and SqlDbType.Structured provided by SQL as empty strings (unsupported) - { "numeric", true} // Not present in SqlDbType, however can be returned by sql functions like LAG and should map to decimal. + { "numeric", true}, // Not present in SqlDbType, however can be returned by sql functions like LAG and should map to decimal. + { "vector", true } // SqlDbType.Vector }; } diff --git a/src/Core/Resolvers/MsSqlQueryBuilder.cs b/src/Core/Resolvers/MsSqlQueryBuilder.cs index 7adedc64d4..7769dcf94b 100644 --- a/src/Core/Resolvers/MsSqlQueryBuilder.cs +++ b/src/Core/Resolvers/MsSqlQueryBuilder.cs @@ -666,8 +666,7 @@ AND ty.name IN N'hierarchyid', N'sql_variant', N'xml', - N'rowversion', - N'vector' + N'rowversion' ) ) THEN 1 ELSE 0 @@ -712,8 +711,7 @@ AND ty.name IN N'hierarchyid', N'sql_variant', N'xml', - N'rowversion', - N'vector' + N'rowversion' ) ) ) diff --git a/src/Core/Resolvers/MsSqlQueryExecutor.cs b/src/Core/Resolvers/MsSqlQueryExecutor.cs index 2b5807c8fd..b4f84757b3 100644 --- a/src/Core/Resolvers/MsSqlQueryExecutor.cs +++ b/src/Core/Resolvers/MsSqlQueryExecutor.cs @@ -18,6 +18,7 @@ using Azure.Identity; using Microsoft.AspNetCore.Http; using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlTypes; using Microsoft.Extensions.Logging; namespace Azure.DataApiBuilder.Core.Resolvers @@ -695,6 +696,19 @@ public override SqlCommand PrepareDbCommand( parameter.Size = parameterEntry.Value.Length.Value; } + // if sqldbtype is vector then set the value as an SqlVector object + if (parameter.SqlDbType is SqlDbType.Vector) + { + List values = new(); + foreach (float val in (Array)parameter.Value) + { + values.Add(val); + } + + SqlVector value = new(values.ToArray()); + parameter.Value = value; + } + cmd.Parameters.Add(parameter); } } diff --git a/src/Core/Resolvers/QueryExecutor.cs b/src/Core/Resolvers/QueryExecutor.cs index d96b19af38..981252bf8f 100644 --- a/src/Core/Resolvers/QueryExecutor.cs +++ b/src/Core/Resolvers/QueryExecutor.cs @@ -13,6 +13,7 @@ using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Http; +using Microsoft.Data.SqlTypes; using Microsoft.Extensions.Logging; using Polly; using Polly.Retry; @@ -502,7 +503,7 @@ public async Task { if (!ConfigProvider.GetConfig().MaxResponseSizeLogicEnabled()) { - dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + dbResultSetRow.Columns.Add(columnName, GetColumnInformation(dbDataReader, columnName)); } else { @@ -554,7 +555,7 @@ public DbResultSet { if (!ConfigProvider.GetConfig().MaxResponseSizeLogicEnabled()) { - dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + dbResultSetRow.Columns.Add(columnName, GetColumnInformation(dbDataReader, columnName)); } else { @@ -822,7 +823,7 @@ internal int StreamDataIntoDbResultSetRow(DbDataReader dbDataReader, DbResultSet { dataRead = columnSize; ValidateSize(availableBytes, dataRead); - dbResultSetRow.Columns.Add(columnName, dbDataReader[columnName]); + dbResultSetRow.Columns.Add(columnName, GetColumnInformation(dbDataReader, columnName)); } return dataRead; @@ -885,6 +886,22 @@ private void ValidateSize(long availableSizeBytes, long sizeToBeReadBytes) } } + /// + /// Helper function to get column information from the DbDataReader and handle special cases like SqlVector. + /// + /// + /// + /// + private static object GetColumnInformation(DbDataReader dbDataReader, string columnName) + { + if (dbDataReader[columnName] is SqlVector columnValue) + { + return columnValue.Memory; + } + + return dbDataReader[columnName]; + } + internal virtual void AddDbExecutionTimeToMiddlewareContext(long time) { HttpContext? httpContext = HttpContextAccessor?.HttpContext; diff --git a/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs b/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs index 99a5b1e72c..c6593047a3 100644 --- a/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs +++ b/src/Core/Resolvers/Sql Query Structures/BaseSqlQueryStructure.cs @@ -4,6 +4,7 @@ using System.Data; using System.Globalization; using System.Net; +using System.Text.Json; using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; @@ -452,10 +453,50 @@ protected static object ParseParamAsSystemType(string param, Type systemType) "Guid" => Guid.Parse(param), "TimeOnly" => TimeOnly.Parse(param), "TimeSpan" => TimeOnly.Parse(param), + "Single[]" => ParseArrayIntoSystemType(param, systemType), _ => throw new NotSupportedException($"{systemType.Name} is not supported") }; } + /// + /// Takes the array of the parameter we are going to parse and converts each element to the specified system type. + /// + /// + /// + /// + /// + /// + private static object ParseArrayIntoSystemType(string param, Type systemType) + { + Type typeOfArray; + switch (systemType.Name) + { + case "Single[]": + typeOfArray = typeof(Single); + break; + + default: + throw new NotSupportedException($"{systemType.Name} is not supported"); + } + + try + { + List list = new(); + object[] values = JsonSerializer.Deserialize(param) ?? Array.Empty(); + for (int i = 0; i < values.Length; i++) + { + string stringValue = values[i]?.ToString() ?? string.Empty; + values[i] = ParseParamAsSystemType(stringValue, typeOfArray); + } + + return values; + } + catch + { + throw new FormatException($"Expected an array for {systemType.Name} but got an unexpected value"); + } + } + /// /// Very similar to GQLArgumentToDictParams but only extracts the argument names from /// the specified field which means that the method does not need a middleware context diff --git a/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs b/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs index 166bcd1b35..00c1727384 100644 --- a/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs +++ b/src/Core/Services/MetadataProviders/MsSqlMetadataProvider.cs @@ -15,6 +15,7 @@ using Azure.DataApiBuilder.Service.Exceptions; using Azure.DataApiBuilder.Service.GraphQLBuilder; using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlTypes; using Microsoft.Extensions.Logging; using static Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLNaming; @@ -120,6 +121,15 @@ protected override void PopulateColumnDefinitionWithHasDefaultAndDbType( columnDefinition.DbType = TypeHelper.GetDbTypeFromSystemType(columnDefinition.SystemType); string sqlDbTypeName = (string)columnInfo["DATA_TYPE"]; + + if (columnDefinition.SystemType == typeof(SqlVector)) + { + sqlDbTypeName = "vector"; // Currently the "DATA_TYPE" column returns "varbinary" for vector type columns. This is a known issue https://learn.microsoft.com/en-us/sql/t-sql/data-types/vector-data-type?view=sql-server-ver17&tabs=csharp#known-issues + columnDefinition.IsArrayType = true; + columnDefinition.ElementSystemType = typeof(Single); + columnDefinition.SystemType = columnDefinition.ElementSystemType.MakeArrayType(); + } + if (Enum.TryParse(sqlDbTypeName, ignoreCase: true, out SqlDbType sqlDbType)) { // The DbType enum in .NET does not distinguish between VarChar and NVarChar. Both are mapped to DbType.String. diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index 979da52eb6..e66968de13 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -1484,12 +1484,18 @@ private static OpenApiSchema CreateComponentSchema( if (metadataProvider.TryGetBackingColumn(entityName, field, out string? backingColumnValue) && !string.IsNullOrEmpty(backingColumnValue)) { string typeMetadata = string.Empty; + string subTypeMetadata = string.Empty; string formatMetadata = string.Empty; string? fieldDescription = null; if (dbObject.SourceDefinition.Columns.TryGetValue(backingColumnValue, out ColumnDefinition? columnDef)) { typeMetadata = TypeHelper.GetJsonDataTypeFromSystemType(columnDef.SystemType).ToString().ToLower(); + + if (string.Equals(typeMetadata, JsonDataType.Array.ToString().ToLower(), StringComparison.OrdinalIgnoreCase)) + { + subTypeMetadata = TypeHelper.GetJsonDataTypeFromSystemType(columnDef.ElementSystemType!).ToString().ToLower(); + } } if (entityConfig?.Fields != null) @@ -1502,7 +1508,8 @@ private static OpenApiSchema CreateComponentSchema( { Type = typeMetadata, Format = formatMetadata, - Description = fieldDescription + Description = fieldDescription, + Items = !string.IsNullOrWhiteSpace(subTypeMetadata) ? new OpenApiSchema() { Type = subTypeMetadata } : null }); } } diff --git a/src/Core/Services/TypeHelper.cs b/src/Core/Services/TypeHelper.cs index 0a95744abb..db47fe6ce9 100644 --- a/src/Core/Services/TypeHelper.cs +++ b/src/Core/Services/TypeHelper.cs @@ -8,6 +8,7 @@ using Azure.DataApiBuilder.Core.Services.OpenAPI; using Azure.DataApiBuilder.Service.Exceptions; using HotChocolate.Language; +using Microsoft.Data.SqlTypes; using Microsoft.OData.Edm; namespace Azure.DataApiBuilder.Core.Services @@ -46,7 +47,8 @@ public static class TypeHelper [typeof(byte[])] = DbType.Binary, [typeof(TimeOnly)] = DbType.Time, [typeof(TimeSpan)] = DbType.Time, - [typeof(object)] = DbType.Object + [typeof(object)] = DbType.Object, + [typeof(SqlVector)] = DbType.Single }; /// @@ -77,7 +79,8 @@ public static class TypeHelper [typeof(TimeOnly)] = JsonDataType.String, [typeof(object)] = JsonDataType.Object, [typeof(DateTime)] = JsonDataType.String, - [typeof(DateTimeOffset)] = JsonDataType.String + [typeof(DateTimeOffset)] = JsonDataType.String, + [typeof(Single[])] = JsonDataType.Array }; /// @@ -111,7 +114,8 @@ public static class TypeHelper [SqlDbType.TinyInt] = typeof(byte), [SqlDbType.UniqueIdentifier] = typeof(Guid), [SqlDbType.VarBinary] = typeof(byte[]), - [SqlDbType.VarChar] = typeof(string) + [SqlDbType.VarChar] = typeof(string), + [SqlDbType.Vector] = typeof(float) }; private static Dictionary _sqlDbDateTimeTypeToDbType = new() diff --git a/src/Service.Tests/DatabaseSchema-MsSql.sql b/src/Service.Tests/DatabaseSchema-MsSql.sql index 4e87394aee..91bdd69f59 100644 --- a/src/Service.Tests/DatabaseSchema-MsSql.sql +++ b/src/Service.Tests/DatabaseSchema-MsSql.sql @@ -40,6 +40,7 @@ DROP TABLE IF EXISTS stocks; DROP TABLE IF EXISTS comics; DROP TABLE IF EXISTS brokers; DROP TABLE IF EXISTS type_table; +DROP TABLE IF EXISTS vector_type_table; DROP TABLE IF EXISTS trees; DROP TABLE IF EXISTS fungi; DROP TABLE IF EXISTS empty_table; @@ -232,6 +233,12 @@ CREATE TABLE type_table( uuid_types uniqueidentifier DEFAULT newid() ); +CREATE TABLE vector_type_table( + id int IDENTITY(5001, 1) PRIMARY KEY, + vector_data vector(3), + vector_data_max vector(1998) +); + CREATE TABLE trees ( treeId int PRIMARY KEY, species varchar(max), @@ -608,6 +615,23 @@ VALUES INSERT INTO type_table(id, uuid_types) values(10, 'D1D021A8-47B4-4AE4-B718-98E89C41A161'); SET IDENTITY_INSERT type_table OFF +SET IDENTITY_INSERT vector_type_table ON +INSERT INTO vector_type_table(id, vector_data) +VALUES + (1, '[0.5, 0.25, 0.75]'), + (2, '[1.5, -2.5, 3.5]'), + (3, NULL), + (4, '[1.0, 2.0, 3.0]'), + (5, '[4.0, 5.0, 6.0]'), + (6, '[7.0, 8.0, 9.0]'); + +INSERT INTO vector_type_table(id, vector_data_max) +VALUES (7, CAST('[' + ( + SELECT STRING_AGG(CAST(value AS NVARCHAR(MAX)), ',') WITHIN GROUP (ORDER BY value) + FROM GENERATE_SERIES(1, 1998) +) + ']' AS vector(1998))); +SET IDENTITY_INSERT vector_type_table OFF + SET IDENTITY_INSERT sales ON INSERT INTO sales(id, item_name, subtotal, tax) VALUES (1, 'Watch', 249.00, 20.59), (2, 'Montior', 120.50, 11.12); SET IDENTITY_INSERT sales OFF diff --git a/src/Service.Tests/SqlTests/RestApiTests/MsSqlRestVectorTypesTests.cs b/src/Service.Tests/SqlTests/RestApiTests/MsSqlRestVectorTypesTests.cs new file mode 100644 index 0000000000..15f496e902 --- /dev/null +++ b/src/Service.Tests/SqlTests/RestApiTests/MsSqlRestVectorTypesTests.cs @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.SqlTests.RestApiTests +{ + /// + /// Tests for SQL Server vector column support via REST endpoints (read and write). + /// Verifies that vector columns are returned as JSON arrays of numbers via REST GET requests + /// and can be inserted/updated/deleted via REST POST/PUT/PATCH/DELETE requests. + /// This mirrors the pattern used for PostgreSQL array types in + /// . + /// NOTE: The vector data type requires SQL Server 2025 / Azure SQL. + /// + [TestClass, TestCategory(TestCategory.MSSQL)] + public class MsSqlRestVectorTypesTests : SqlTestBase + { + private const string VECTOR_TYPE_REST_PATH = "api/VectorType"; + + /// + /// Tolerance used when comparing the single-precision components of a vector, + /// since vector(N) stores 32-bit floats which may not round-trip exactly through JSON. + /// + private const double VECTOR_COMPONENT_DELTA = 0.0001; + + [ClassInitialize] + public static async Task SetupAsync(TestContext context) + { + DatabaseEngine = TestCategory.MSSQL; + await InitializeTestFixture(); + } + + #region Read Tests + + [DataTestMethod] + [DataRow(VECTOR_TYPE_REST_PATH, 7, new[] { 0.5f, 0.25f, 0.75f }, DisplayName = "GET for Vector data type")] + [DataRow($"{VECTOR_TYPE_REST_PATH}/id/2", 1, new[] { 1.5f, -2.5f, 3.5f }, DisplayName = "GET for Vector data type by primary key")] + [DataRow($"{VECTOR_TYPE_REST_PATH}/id/3", 1, null, DisplayName = "GET for Vector data type with null vector")] + public async Task GetVectorTypeList(string vectorRestPath, int expectedItems, float[] expectedValues) + { + HttpResponseMessage response = await HttpClient.GetAsync(vectorRestPath); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + + string body = await response.Content.ReadAsStringAsync(); + JsonElement root = JsonDocument.Parse(body).RootElement; + JsonElement items = root.GetProperty("value"); + + Assert.AreEqual(expectedItems, items.GetArrayLength(), $"Expected {expectedItems} items, got {items.GetArrayLength()}"); + + // Records are ordered by the primary key ascending, so the first record is id = 1. + JsonElement first = items[0]; + AssertVectorEquals(first.GetProperty("vector_data"), expectedValues); + } + + /// + /// GET /api/VectorType/id/7 - Verify that a vector using the maximum supported dimension count (1998) + /// round-trips through REST and is returned with the correct number of dimensions. + /// + [TestMethod] + public async Task GetVectorTypeWithMaxDimensions() + { + HttpResponseMessage response = await HttpClient.GetAsync($"{VECTOR_TYPE_REST_PATH}/id/7"); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + + string body = await response.Content.ReadAsStringAsync(); + JsonElement value = JsonDocument.Parse(body).RootElement.GetProperty("value")[0]; + + Assert.AreEqual(7, value.GetProperty("id").GetInt32()); + + JsonElement maxVector = value.GetProperty("vector_data_max"); + Assert.AreEqual(JsonValueKind.Array, maxVector.ValueKind, "Expected the maximum-dimension vector to be serialized as a JSON array."); + Assert.AreEqual(1998, maxVector.GetArrayLength(), "Expected the maximum-dimension vector to have 1998 components."); + + int i = 1; + foreach (JsonElement vectorVal in maxVector.EnumerateArray()) + { + Assert.AreEqual(i, vectorVal.GetDouble(), VECTOR_COMPONENT_DELTA); + i++; + } + } + + #endregion + + #region Write Tests + + /// + /// POST /api/VectorType - Verify that a new record with a vector value can be inserted and is + /// returned (and persisted) as a JSON array. + /// + [DataTestMethod] + [DataRow("{ \"vector_data\": [0.125, 0.25, 0.5] }", new[] { 0.125f, 0.25f, 0.5f }, true, DisplayName = "Insert valid vector")] + [DataRow("{ \"vector_data\": null }", null, true, DisplayName = "Insert valid null vector")] + [DataRow("{ \"vector_data\": [5e-1, 2.5e-1, 7.5e-1] }", new[] { 0.5f, 0.25f, 0.75f }, true, DisplayName = "Insert valid vector with scientific notation")] + [DataRow("{ \"vector_data\": [\"0.5\", \"0.25\", \"0.75\"] }", new[] { 0.5f, 0.25f, 0.75f }, true, DisplayName = "Insert valid vector with numbers as string values")] + [DataRow("{ \"vector_data\": [1.25, 2.25, 3.25, 4.25] }", null, false, DisplayName = "Insert invalid vector with more dimensions than allowed")] + [DataRow("{ \"vector_data\": [\"not\", \"a\", \"number\"] }", null, false, DisplayName = "Insert invalid vector with invalid values")] + public async Task InsertVectorType(string requestBody, float[] expectedValue, bool expectedSuccess) + { + HttpResponseMessage postResponse = await HttpClient.PostAsync( + VECTOR_TYPE_REST_PATH, + new StringContent(requestBody, Encoding.UTF8, "application/json")); + + if (expectedSuccess) + { + Assert.AreEqual(HttpStatusCode.Created, postResponse.StatusCode); + + JsonElement postElement = JsonDocument.Parse(await postResponse.Content.ReadAsStringAsync()) + .RootElement.GetProperty("value")[0]; + int newId = postElement.GetProperty("id").GetInt32(); + + // Confirm the value was persisted by reading it back. + JsonElement readBack = await GetRecordByIdAsync(newId); + AssertVectorEquals(readBack.GetProperty("vector_data"), expectedValue); + await DeleteVectorType(newId); + } + else + { + Assert.IsFalse(postResponse.IsSuccessStatusCode, "Expected that inserting vector should fail."); + } + } + + /// + /// PUT Verify that an existing record's vector value is replaced (full update). + /// + [TestMethod] + public async Task PutVectorType_Update() + { + // Change vector value + float[] expected = new[] { 9.5f, 8.5f, 7.5f }; + string requestBody = "{ \"vector_data\": [9.5, 8.5, 7.5] }"; + + HttpResponseMessage response = await HttpClient.PutAsync( + $"{VECTOR_TYPE_REST_PATH}/id/4", + new StringContent(requestBody, Encoding.UTF8, "application/json")); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + + JsonElement updated = JsonDocument.Parse(await response.Content.ReadAsStringAsync()) + .RootElement.GetProperty("value")[0]; + Assert.AreEqual(4, updated.GetProperty("id").GetInt32()); + + JsonElement readBack = await GetRecordByIdAsync(4); + AssertVectorEquals(readBack.GetProperty("vector_data"), expected); + + // Restore vector value to original + expected = new[] { 1.0f, 2.0f, 3.0f }; + requestBody = "{ \"vector_data\": [1.0, 2.0, 3.0] }"; + + HttpResponseMessage restoreResponse = await HttpClient.PutAsync( + $"{VECTOR_TYPE_REST_PATH}/id/4", + new StringContent(requestBody, Encoding.UTF8, "application/json")); + Assert.AreEqual(HttpStatusCode.OK, restoreResponse.StatusCode); + + JsonElement restoreUpdated = JsonDocument.Parse(await restoreResponse.Content.ReadAsStringAsync()) + .RootElement.GetProperty("value")[0]; + Assert.AreEqual(4, restoreUpdated.GetProperty("id").GetInt32()); + + JsonElement restoreReadBack = await GetRecordByIdAsync(4); + AssertVectorEquals(restoreReadBack.GetProperty("vector_data"), expected); + } + + /// + /// PATCH Verify that an existing record's vector value is updated. + /// + [TestMethod] + public async Task PatchVectorType_Update() + { + // Change vector value + float[] expected = new[] { 1.25f, 2.25f, 3.25f }; + string requestBody = "{ \"vector_data\": [1.25, 2.25, 3.25] }"; + + HttpResponseMessage response = await HttpClient.PatchAsync( + $"{VECTOR_TYPE_REST_PATH}/id/4", + new StringContent(requestBody, Encoding.UTF8, "application/json")); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + + JsonElement updated = JsonDocument.Parse(await response.Content.ReadAsStringAsync()) + .RootElement.GetProperty("value")[0]; + Assert.AreEqual(4, updated.GetProperty("id").GetInt32()); + + JsonElement readBack = await GetRecordByIdAsync(4); + AssertVectorEquals(readBack.GetProperty("vector_data"), expected); + + // Restore vector value to original + expected = new[] { 1.0f, 2.0f, 3.0f }; + requestBody = "{ \"vector_data\": [1.0, 2.0, 3.0] }"; + + HttpResponseMessage restoreResponse = await HttpClient.PutAsync( + $"{VECTOR_TYPE_REST_PATH}/id/4", + new StringContent(requestBody, Encoding.UTF8, "application/json")); + Assert.AreEqual(HttpStatusCode.OK, restoreResponse.StatusCode); + + JsonElement restoreUpdated = JsonDocument.Parse(await restoreResponse.Content.ReadAsStringAsync()) + .RootElement.GetProperty("value")[0]; + Assert.AreEqual(4, restoreUpdated.GetProperty("id").GetInt32()); + + JsonElement restoreReadBack = await GetRecordByIdAsync(4); + AssertVectorEquals(restoreReadBack.GetProperty("vector_data"), expected); + } + + #endregion + + #region Query Option Tests + + [DataTestMethod] + [DataRow("?$filter=vector_data%20eq%201", DisplayName = "Fail GET with $filter on vector column")] + [DataRow("?$orderby=vector_data ASC", DisplayName = "Fail GET with $orderby on vector column")] + public async Task ArgumentsOnVectorColumnFail(string queryOptions) + { + HttpResponseMessage response = await HttpClient.GetAsync($"{VECTOR_TYPE_REST_PATH}{queryOptions}"); + Assert.AreEqual(HttpStatusCode.BadRequest, response.StatusCode, "Query options on vector columns should be rejected."); + } + + /// + /// GET /api/VectorType?$first=2&$orderby=id - Verify that pagination works for an entity that has a + /// vector column. The first request returns a page plus a nextLink, and issuing a second request using + /// the $after token extracted from that nextLink also succeeds. + /// + [TestMethod] + public async Task FindWithFirstThenAfterPaginationSucceedsVectorType() + { + // First page: limit to two records, ordered by primary key for a deterministic cursor. + HttpResponseMessage firstPageResponse = await HttpClient.GetAsync($"{VECTOR_TYPE_REST_PATH}?$first=2&$orderby=id"); + Assert.AreEqual(HttpStatusCode.OK, firstPageResponse.StatusCode); + + JsonElement firstPageRoot = JsonDocument.Parse(await firstPageResponse.Content.ReadAsStringAsync()).RootElement; + Assert.AreEqual(2, firstPageRoot.GetProperty("value").GetArrayLength(), "Expected the first page to contain exactly two records."); + Assert.IsTrue(firstPageRoot.TryGetProperty("nextLink", out JsonElement nextLinkElement), "Expected a nextLink on the first page."); + + // Extract the $after token from the nextLink and use it to request the next page. + string afterToken = ExtractAfterToken(nextLinkElement.GetString()); + Assert.IsFalse(string.IsNullOrEmpty(afterToken), "Expected a non-empty $after token in the nextLink."); + + HttpResponseMessage secondPageResponse = await HttpClient.GetAsync($"{VECTOR_TYPE_REST_PATH}?$first=2&$orderby=id&$after={afterToken}"); + Assert.AreEqual(HttpStatusCode.OK, secondPageResponse.StatusCode, "Expected the request using the $after token to succeed."); + + JsonElement secondPageRoot = JsonDocument.Parse(await secondPageResponse.Content.ReadAsStringAsync()).RootElement; + Assert.IsTrue(secondPageRoot.GetProperty("value").GetArrayLength() >= 1, "Expected the second page to contain at least one record."); + } + + #endregion + + #region Helpers + + /// + /// DELETE /api/VectorType/id/6 - Verify that a record with a vector column can be deleted and is + /// no longer retrievable. + /// + private static async Task DeleteVectorType(int id) + { + HttpResponseMessage deleteResponse = await HttpClient.DeleteAsync($"{VECTOR_TYPE_REST_PATH}/id/{id}"); + Assert.AreEqual(HttpStatusCode.NoContent, deleteResponse.StatusCode); + } + + /// + /// Fetches a single VectorType record by its primary key and returns the record element. + /// + private static async Task GetRecordByIdAsync(int id) + { + HttpResponseMessage response = await HttpClient.GetAsync($"{VECTOR_TYPE_REST_PATH}/id/{id}"); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + string body = await response.Content.ReadAsStringAsync(); + return JsonDocument.Parse(body).RootElement.GetProperty("value")[0].Clone(); + } + + /// + /// Extracts the raw (URL-encoded) value of the $after query parameter from a pagination nextLink. + /// The token is returned exactly as emitted by the engine so it can be replayed verbatim in a + /// follow-up request without any re-encoding. + /// + private static string ExtractAfterToken(string nextLink) + { + if (string.IsNullOrEmpty(nextLink)) + { + return string.Empty; + } + + const string afterParam = "$after="; + int afterIndex = nextLink.IndexOf(afterParam, StringComparison.Ordinal); + if (afterIndex < 0) + { + return string.Empty; + } + + string afterToken = nextLink.Substring(afterIndex + afterParam.Length); + int ampersandIndex = afterToken.IndexOf('&'); + return ampersandIndex >= 0 ? afterToken.Substring(0, ampersandIndex) : afterToken; + } + + /// + /// Asserts that the given JSON element is an array whose components match the expected vector + /// within . + /// + private static void AssertVectorEquals(JsonElement actual, float[] expected) + { + if (expected == null) + { + Assert.AreEqual(JsonValueKind.Null, actual.ValueKind, "Expected a null vector, but got a non-null value."); + return; + } + + Assert.AreEqual(expected.Length, actual.GetArrayLength(), "Vector dimension mismatch."); + + int i = 0; + foreach (JsonElement element in actual.EnumerateArray()) + { + Assert.AreEqual(expected[i], element.GetDouble(), VECTOR_COMPONENT_DELTA, $"Vector component at expected {expected[i]} and got {element.GetDouble()}."); + i++; + } + } + + #endregion + } +} diff --git a/src/Service.Tests/dab-config.MsSql.json b/src/Service.Tests/dab-config.MsSql.json index 6a41d8ee13..4de4f52b5f 100644 --- a/src/Service.Tests/dab-config.MsSql.json +++ b/src/Service.Tests/dab-config.MsSql.json @@ -1802,6 +1802,58 @@ } ] }, + "VectorType": { + "source": { + "object": "vector_type_table", + "type": "table" + }, + "graphql": { + "enabled": false, + "type": { + "singular": "VectorType", + "plural": "VectorTypes" + } + }, + "rest": { + "enabled": true + }, + "permissions": [ + { + "role": "anonymous", + "actions": [ + { + "action": "create" + }, + { + "action": "read" + }, + { + "action": "delete" + }, + { + "action": "update" + } + ] + }, + { + "role": "authenticated", + "actions": [ + { + "action": "create" + }, + { + "action": "read" + }, + { + "action": "delete" + }, + { + "action": "update" + } + ] + } + ] + }, "stocks_price": { "source": { "object": "stocks_price", @@ -4022,4 +4074,4 @@ ] } } -} \ No newline at end of file +}