Add dimensions parameter for embeddings; reformat SQL

Updated Program.cs to include dimensions parameter for AddAzureOpenAITextEmbeddingGeneration sourced from aiSettings.Embedding.Dimensions. Reformatted SQL command texts in VectorSearchService.cs for better readability. Introduced EmbeddingServiceSettings class in AzureOpenAISettings.cs to allow optional dimensions configuration. Updated appsettings.json to include new Dimensions property under Embedding section.
This commit is contained in:
Marco Minerva
2024-09-30 17:53:59 +02:00
parent 3e95251485
commit 2dff0aae55
4 changed files with 47 additions and 14 deletions
+1 -1
View File
@@ -30,7 +30,7 @@ builder.Services.AddMemoryCache();
// Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions,
// so that embeddings themselves can be generated more accurately.
builder.Services.AddKernel()
.AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey)
.AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions)
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey);
builder.Services.AddScoped<ChatService>();
@@ -34,7 +34,12 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await using var command = sqlConnection.CreateCommand();
command.Transaction = (SqlTransaction)transaction;
command.CommandText = "INSERT INTO Documents2 (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate)";
command.CommandText = """
INSERT INTO Documents2 (Id, [Name], CreationDate)
OUTPUT INSERTED.Id
VALUES (@Id, @Name, @CreationDate);
""";
command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid()));
command.Parameters.AddWithValue("@Name", name);
command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow());
@@ -51,7 +56,11 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
{
command.Parameters.Clear();
command.CommandText = "INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR(1536)))";
command.CommandText = $"""
INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding)
VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embedding.Length})));
""";
command.Parameters.AddWithValue("@DocumentId", documentId);
command.Parameters.AddWithValue("@Index", index++);
command.Parameters.AddWithValue("@Content", paragraph);
@@ -70,11 +79,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 WHERE DocumentId = Documents2.Id) FROM Documents2 ORDER BY [Name]";
command.CommandText = """
SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2
WHERE DocumentId = Documents2.Id)
FROM Documents2 ORDER BY [Name];
""";
var documents = new List<Document>();
using var reader = await command.ExecuteReaderAsync();
await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var id = reader.GetGuid(0);
@@ -93,12 +106,17 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT Id, [Index], Content FROM DocumentChunks2 WHERE DocumentId = @DocumentId ORDER BY [Index]";
command.CommandText = """
SELECT Id, [Index], Content
FROM DocumentChunks2 WHERE DocumentId = @DocumentId
ORDER BY [Index];
""";
command.Parameters.AddWithValue("@DocumentId", documentId);
var documentChunks = new List<DocumentChunk>();
using var reader = await command.ExecuteReaderAsync();
await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var id = reader.GetGuid(0);
@@ -116,11 +134,15 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId";
command.CommandText = """
"SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX))
FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId;
""";
command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId);
command.Parameters.AddWithValue("@DocumentId", documentId);
using var reader = await command.ExecuteReaderAsync();
await using var reader = await command.ExecuteReaderAsync();
if (reader.HasRows && await reader.ReadAsync())
{
var index = reader.GetInt32(0);
@@ -140,7 +162,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync();
}
using var command = sqlConnection.CreateCommand();
await using var command = sqlConnection.CreateCommand();
command.Transaction = transaction as SqlTransaction;
command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId";
@@ -160,13 +182,18 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await sqlConnection.OpenAsync();
await using var command = sqlConnection.CreateCommand();
command.CommandText = "SELECT TOP (@MaxRelevantChunks) Content FROM DocumentChunks2 ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR(1536)))";
command.CommandText = $"""
SELECT TOP (@MaxRelevantChunks) Content
FROM DocumentChunks2
ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length})));
""";
command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks);
command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding));
var chunks = new List<string>();
using var reader = await command.ExecuteReaderAsync();
await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var content = reader.GetString(0);
@@ -4,7 +4,7 @@ public class AzureOpenAISettings
{
public required ServiceSettings ChatCompletion { get; init; }
public required ServiceSettings Embedding { get; init; }
public required EmbeddingServiceSettings Embedding { get; init; }
}
public class ServiceSettings
@@ -15,3 +15,8 @@ public class ServiceSettings
public required string ApiKey { get; init; }
}
public class EmbeddingServiceSettings : ServiceSettings
{
public int? Dimensions { get; set; }
}
+2 -1
View File
@@ -11,7 +11,8 @@
"Embedding": {
"Endpoint": "",
"Deployment": "",
"ApiKey": ""
"ApiKey": "",
"Dimensions": null
}
},
"AppSettings": {