From 2dff0aae55e0c1cce4dacf7f5616b17336e3efc0 Mon Sep 17 00:00:00 2001 From: Marco Minerva Date: Mon, 30 Sep 2024 17:53:59 +0200 Subject: [PATCH] Add dimensions parameter for embeddings; reformat SQL Updated Program.cs to include dimensions parameter for AddAzureOpenAITextEmbeddingGeneration sourced from aiSettings.Embedding.Dimensions. Reformatted SQL command texts in VectorSearchService.cs for better readability. Introduced EmbeddingServiceSettings class in AzureOpenAISettings.cs to allow optional dimensions configuration. Updated appsettings.json to include new Dimensions property under Embedding section. --- SqlDatabaseVectorSearch/Program.cs | 2 +- .../Services/VectorSearchService.cs | 49 ++++++++++++++----- .../Settings/AzureOpenAISettings.cs | 7 ++- SqlDatabaseVectorSearch/appsettings.json | 3 +- 4 files changed, 47 insertions(+), 14 deletions(-) diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 8832861..244729a 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -30,7 +30,7 @@ builder.Services.AddMemoryCache(); // Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions, // so that embeddings themselves can be generated more accurately. builder.Services.AddKernel() - .AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey) + .AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions) .AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey); builder.Services.AddScoped(); diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 013a60b..207689d 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -34,7 +34,12 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await using var command = sqlConnection.CreateCommand(); command.Transaction = (SqlTransaction)transaction; - command.CommandText = "INSERT INTO Documents2 (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate)"; + command.CommandText = """ + INSERT INTO Documents2 (Id, [Name], CreationDate) + OUTPUT INSERTED.Id + VALUES (@Id, @Name, @CreationDate); + """; + command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid())); command.Parameters.AddWithValue("@Name", name); command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow()); @@ -51,7 +56,11 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene { command.Parameters.Clear(); - command.CommandText = "INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR(1536)))"; + command.CommandText = $""" + INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) + VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embedding.Length}))); + """; + command.Parameters.AddWithValue("@DocumentId", documentId); command.Parameters.AddWithValue("@Index", index++); command.Parameters.AddWithValue("@Content", paragraph); @@ -70,11 +79,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 WHERE DocumentId = Documents2.Id) FROM Documents2 ORDER BY [Name]"; + command.CommandText = """ + SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 + WHERE DocumentId = Documents2.Id) + FROM Documents2 ORDER BY [Name]; + """; var documents = new List(); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); while (await reader.ReadAsync()) { var id = reader.GetGuid(0); @@ -93,12 +106,17 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT Id, [Index], Content FROM DocumentChunks2 WHERE DocumentId = @DocumentId ORDER BY [Index]"; + command.CommandText = """ + SELECT Id, [Index], Content + FROM DocumentChunks2 WHERE DocumentId = @DocumentId + ORDER BY [Index]; + """; + command.Parameters.AddWithValue("@DocumentId", documentId); var documentChunks = new List(); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); while (await reader.ReadAsync()) { var id = reader.GetGuid(0); @@ -116,11 +134,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId"; + command.CommandText = """ + "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) + FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId; + """; + command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId); command.Parameters.AddWithValue("@DocumentId", documentId); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); if (reader.HasRows && await reader.ReadAsync()) { var index = reader.GetInt32(0); @@ -140,7 +162,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); } - using var command = sqlConnection.CreateCommand(); + await using var command = sqlConnection.CreateCommand(); command.Transaction = transaction as SqlTransaction; command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId"; @@ -160,13 +182,18 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT TOP (@MaxRelevantChunks) Content FROM DocumentChunks2 ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR(1536)))"; + command.CommandText = $""" + SELECT TOP (@MaxRelevantChunks) Content + FROM DocumentChunks2 + ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length}))); + """; + command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks); command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding)); var chunks = new List(); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); while (await reader.ReadAsync()) { var content = reader.GetString(0); diff --git a/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs b/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs index 5c9abb1..e85d51d 100644 --- a/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs +++ b/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs @@ -4,7 +4,7 @@ public class AzureOpenAISettings { public required ServiceSettings ChatCompletion { get; init; } - public required ServiceSettings Embedding { get; init; } + public required EmbeddingServiceSettings Embedding { get; init; } } public class ServiceSettings @@ -15,3 +15,8 @@ public class ServiceSettings public required string ApiKey { get; init; } } + +public class EmbeddingServiceSettings : ServiceSettings +{ + public int? Dimensions { get; set; } +} diff --git a/SqlDatabaseVectorSearch/appsettings.json b/SqlDatabaseVectorSearch/appsettings.json index 5c5c42b..75f72ec 100644 --- a/SqlDatabaseVectorSearch/appsettings.json +++ b/SqlDatabaseVectorSearch/appsettings.json @@ -11,7 +11,8 @@ "Embedding": { "Endpoint": "", "Deployment": "", - "ApiKey": "" + "ApiKey": "", + "Dimensions": null } }, "AppSettings": {