Add dimensions parameter for embeddings; reformat SQL

Updated Program.cs to include dimensions parameter for AddAzureOpenAITextEmbeddingGeneration sourced from aiSettings.Embedding.Dimensions. Reformatted SQL command texts in VectorSearchService.cs for better readability. Introduced EmbeddingServiceSettings class in AzureOpenAISettings.cs to allow optional dimensions configuration. Updated appsettings.json to include new Dimensions property under Embedding section.
This commit is contained in:
Marco Minerva
2024-09-30 17:53:59 +02:00
parent 3e95251485
commit 2dff0aae55
4 changed files with 47 additions and 14 deletions
+1 -1
View File
@@ -30,7 +30,7 @@ builder.Services.AddMemoryCache();
// Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions, // Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions,
// so that embeddings themselves can be generated more accurately. // so that embeddings themselves can be generated more accurately.
builder.Services.AddKernel() builder.Services.AddKernel()
.AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey) .AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions)
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey); .AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey);
builder.Services.AddScoped<ChatService>(); builder.Services.AddScoped<ChatService>();
@@ -34,7 +34,12 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await using var command = sqlConnection.CreateCommand(); await using var command = sqlConnection.CreateCommand();
command.Transaction = (SqlTransaction)transaction; command.Transaction = (SqlTransaction)transaction;
command.CommandText = "INSERT INTO Documents2 (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate)"; command.CommandText = """
INSERT INTO Documents2 (Id, [Name], CreationDate)
OUTPUT INSERTED.Id
VALUES (@Id, @Name, @CreationDate);
""";
command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid())); command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid()));
command.Parameters.AddWithValue("@Name", name); command.Parameters.AddWithValue("@Name", name);
command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow()); command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow());
@@ -51,7 +56,11 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
{ {
command.Parameters.Clear(); command.Parameters.Clear();
command.CommandText = "INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR(1536)))"; command.CommandText = $"""
INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding)
VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embedding.Length})));
""";
command.Parameters.AddWithValue("@DocumentId", documentId); command.Parameters.AddWithValue("@DocumentId", documentId);
command.Parameters.AddWithValue("@Index", index++); command.Parameters.AddWithValue("@Index", index++);
command.Parameters.AddWithValue("@Content", paragraph); command.Parameters.AddWithValue("@Content", paragraph);
@@ -70,11 +79,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync(); await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand(); await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 WHERE DocumentId = Documents2.Id) FROM Documents2 ORDER BY [Name]"; command.CommandText = """
SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2
WHERE DocumentId = Documents2.Id)
FROM Documents2 ORDER BY [Name];
""";
var documents = new List<Document>(); var documents = new List<Document>();
using var reader = await command.ExecuteReaderAsync(); await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync()) while (await reader.ReadAsync())
{ {
var id = reader.GetGuid(0); var id = reader.GetGuid(0);
@@ -93,12 +106,17 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync(); await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand(); await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT Id, [Index], Content FROM DocumentChunks2 WHERE DocumentId = @DocumentId ORDER BY [Index]"; command.CommandText = """
SELECT Id, [Index], Content
FROM DocumentChunks2 WHERE DocumentId = @DocumentId
ORDER BY [Index];
""";
command.Parameters.AddWithValue("@DocumentId", documentId); command.Parameters.AddWithValue("@DocumentId", documentId);
var documentChunks = new List<DocumentChunk>(); var documentChunks = new List<DocumentChunk>();
using var reader = await command.ExecuteReaderAsync(); await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync()) while (await reader.ReadAsync())
{ {
var id = reader.GetGuid(0); var id = reader.GetGuid(0);
@@ -116,11 +134,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync(); await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand(); await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId"; command.CommandText = """
"SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX))
FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId;
""";
command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId); command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId);
command.Parameters.AddWithValue("@DocumentId", documentId); command.Parameters.AddWithValue("@DocumentId", documentId);
using var reader = await command.ExecuteReaderAsync(); await using var reader = await command.ExecuteReaderAsync();
if (reader.HasRows && await reader.ReadAsync()) if (reader.HasRows && await reader.ReadAsync())
{ {
var index = reader.GetInt32(0); var index = reader.GetInt32(0);
@@ -140,7 +162,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync(); await sqlConnection.OpenAsync();
} }
using var command = sqlConnection.CreateCommand(); await using var command = sqlConnection.CreateCommand();
command.Transaction = transaction as SqlTransaction; command.Transaction = transaction as SqlTransaction;
command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId"; command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId";
@@ -160,13 +182,18 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync(); await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand(); await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT TOP (@MaxRelevantChunks) Content FROM DocumentChunks2 ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR(1536)))"; command.CommandText = $"""
SELECT TOP (@MaxRelevantChunks) Content
FROM DocumentChunks2
ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length})));
""";
command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks); command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks);
command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding)); command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding));
var chunks = new List<string>(); var chunks = new List<string>();
using var reader = await command.ExecuteReaderAsync(); await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync()) while (await reader.ReadAsync())
{ {
var content = reader.GetString(0); var content = reader.GetString(0);
@@ -4,7 +4,7 @@ public class AzureOpenAISettings
{ {
public required ServiceSettings ChatCompletion { get; init; } public required ServiceSettings ChatCompletion { get; init; }
public required ServiceSettings Embedding { get; init; } public required EmbeddingServiceSettings Embedding { get; init; }
} }
public class ServiceSettings public class ServiceSettings
@@ -15,3 +15,8 @@ public class ServiceSettings
public required string ApiKey { get; init; } public required string ApiKey { get; init; }
} }
public class EmbeddingServiceSettings : ServiceSettings
{
public int? Dimensions { get; set; }
}
+2 -1
View File
@@ -11,7 +11,8 @@
"Embedding": { "Embedding": {
"Endpoint": "", "Endpoint": "",
"Deployment": "", "Deployment": "",
"ApiKey": "" "ApiKey": "",
"Dimensions": null
} }
}, },
"AppSettings": { "AppSettings": {