diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 8832861..244729a 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -30,7 +30,7 @@ builder.Services.AddMemoryCache(); // Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions, // so that embeddings themselves can be generated more accurately. builder.Services.AddKernel() - .AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey) + .AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions) .AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey); builder.Services.AddScoped(); diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 013a60b..207689d 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -34,7 +34,12 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await using var command = sqlConnection.CreateCommand(); command.Transaction = (SqlTransaction)transaction; - command.CommandText = "INSERT INTO Documents2 (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate)"; + command.CommandText = """ + INSERT INTO Documents2 (Id, [Name], CreationDate) + OUTPUT INSERTED.Id + VALUES (@Id, @Name, @CreationDate); + """; + command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid())); command.Parameters.AddWithValue("@Name", name); command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow()); @@ -51,7 +56,11 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene { command.Parameters.Clear(); - command.CommandText = "INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR(1536)))"; + command.CommandText = $""" + INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) + VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embedding.Length}))); + """; + command.Parameters.AddWithValue("@DocumentId", documentId); command.Parameters.AddWithValue("@Index", index++); command.Parameters.AddWithValue("@Content", paragraph); @@ -70,11 +79,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 WHERE DocumentId = Documents2.Id) FROM Documents2 ORDER BY [Name]"; + command.CommandText = """ + SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 + WHERE DocumentId = Documents2.Id) + FROM Documents2 ORDER BY [Name]; + """; var documents = new List(); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); while (await reader.ReadAsync()) { var id = reader.GetGuid(0); @@ -93,12 +106,17 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT Id, [Index], Content FROM DocumentChunks2 WHERE DocumentId = @DocumentId ORDER BY [Index]"; + command.CommandText = """ + SELECT Id, [Index], Content + FROM DocumentChunks2 WHERE DocumentId = @DocumentId + ORDER BY [Index]; + """; + command.Parameters.AddWithValue("@DocumentId", documentId); var documentChunks = new List(); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); while (await reader.ReadAsync()) { var id = reader.GetGuid(0); @@ -116,11 +134,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId"; + command.CommandText = """ + "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) + FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId; + """; + command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId); command.Parameters.AddWithValue("@DocumentId", documentId); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); if (reader.HasRows && await reader.ReadAsync()) { var index = reader.GetInt32(0); @@ -140,7 +162,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); } - using var command = sqlConnection.CreateCommand(); + await using var command = sqlConnection.CreateCommand(); command.Transaction = transaction as SqlTransaction; command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId"; @@ -160,13 +182,18 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await sqlConnection.OpenAsync(); await using var command = sqlConnection.CreateCommand(); - command.CommandText = "SELECT TOP (@MaxRelevantChunks) Content FROM DocumentChunks2 ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR(1536)))"; + command.CommandText = $""" + SELECT TOP (@MaxRelevantChunks) Content + FROM DocumentChunks2 + ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length}))); + """; + command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks); command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding)); var chunks = new List(); - using var reader = await command.ExecuteReaderAsync(); + await using var reader = await command.ExecuteReaderAsync(); while (await reader.ReadAsync()) { var content = reader.GetString(0); diff --git a/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs b/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs index 5c9abb1..e85d51d 100644 --- a/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs +++ b/SqlDatabaseVectorSearch/Settings/AzureOpenAISettings.cs @@ -4,7 +4,7 @@ public class AzureOpenAISettings { public required ServiceSettings ChatCompletion { get; init; } - public required ServiceSettings Embedding { get; init; } + public required EmbeddingServiceSettings Embedding { get; init; } } public class ServiceSettings @@ -15,3 +15,8 @@ public class ServiceSettings public required string ApiKey { get; init; } } + +public class EmbeddingServiceSettings : ServiceSettings +{ + public int? Dimensions { get; set; } +} diff --git a/SqlDatabaseVectorSearch/appsettings.json b/SqlDatabaseVectorSearch/appsettings.json index 5c5c42b..75f72ec 100644 --- a/SqlDatabaseVectorSearch/appsettings.json +++ b/SqlDatabaseVectorSearch/appsettings.json @@ -11,7 +11,8 @@ "Embedding": { "Endpoint": "", "Deployment": "", - "ApiKey": "" + "ApiKey": "", + "Dimensions": null } }, "AppSettings": {