Refactor DB operations, rename tables, add Dapper

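Refactored `VectorSearchService.cs` to use Dapper for DB operations, replacing raw ADO.NET commands. Updated methods for inserting, retrieving, and deleting documents and chunks. Modified vector search query to use Dapper's `QueryAsync`. Renamed the `Documents2` and `DocumentChunks2` tables (and their constraints) to `Documents` and `DocumentChunks` in the SQL script and in all queries.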

Updated `SqlDatabaseVectorSearch.csproj` to include Dapper package reference, version `2.1.35`.
This commit is contained in:
Marco Minerva
2024-10-01 11:39:21 +02:00
parent 2dff0aae55
commit 4355f72dab
4 changed files with 74 additions and 102 deletions
+15 -15
View File
@@ -1,4 +1,14 @@
CREATE TABLE [dbo].[DocumentChunks2]( CREATE TABLE [dbo].[Documents](
[Id] [uniqueidentifier] NOT NULL,
[Name] [nvarchar](255) NOT NULL,
[CreationDate] [datetimeoffset](7) NOT NULL,
CONSTRAINT [PK_Documents2] PRIMARY KEY CLUSTERED
(
[Id] ASC
))
GO
CREATE TABLE [dbo].[DocumentChunks](
[Id] [uniqueidentifier] NOT NULL, [Id] [uniqueidentifier] NOT NULL,
[DocumentId] [uniqueidentifier] NOT NULL, [DocumentId] [uniqueidentifier] NOT NULL,
[Index] [int] NOT NULL, [Index] [int] NOT NULL,
@@ -10,23 +20,13 @@ CREATE TABLE [dbo].[DocumentChunks2](
)) ))
GO GO
CREATE TABLE [dbo].[Documents2]( ALTER TABLE [dbo].[DocumentChunks] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks_Documents] FOREIGN KEY([DocumentId])
[Id] [uniqueidentifier] NOT NULL, REFERENCES [dbo].[Documents] ([Id])
[Name] [nvarchar](255) NOT NULL,
[CreationDate] [datetimeoffset](7) NOT NULL,
CONSTRAINT [PK_Documents2] PRIMARY KEY CLUSTERED
(
[Id] ASC
))
GO
ALTER TABLE [dbo].[DocumentChunks2] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks2_Documents2] FOREIGN KEY([DocumentId])
REFERENCES [dbo].[Documents2] ([Id])
ON DELETE CASCADE ON DELETE CASCADE
GO GO
ALTER TABLE [dbo].[Documents2] ADD CONSTRAINT [DF_Documents2_Id] DEFAULT (newsequentialid()) FOR [Id] ALTER TABLE [dbo].[Documents] ADD CONSTRAINT [DF_Documents_Id] DEFAULT (newsequentialid()) FOR [Id]
GO GO
ALTER TABLE [dbo].[DocumentChunks2] ADD CONSTRAINT [DF_DocumentChunks2_Id] DEFAULT (newsequentialid()) FOR [Id] ALTER TABLE [dbo].[DocumentChunks] ADD CONSTRAINT [DF_DocumentChunks_Id] DEFAULT (newsequentialid()) FOR [Id]
GO GO
@@ -1,3 +1,14 @@
namespace SqlDatabaseVectorSearch.Models; using System.Text.Json;
public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding = null); namespace SqlDatabaseVectorSearch.Models;
public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding)
{
public DocumentChunk(Guid Id, int Index, string Content) : this(Id, Index, Content, (float[]?)null)
{
}
public DocumentChunk(Guid Id, int Index, string Content, string Embedding) : this(Id, Index, Content, JsonSerializer.Deserialize<float[]?>(Embedding))
{
}
}
@@ -2,12 +2,14 @@
using System.Data.Common; using System.Data.Common;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using Dapper;
using Microsoft.Data.SqlClient; using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Text; using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
using TinyHelpers.Extensions;
using UglyToad.PdfPig; using UglyToad.PdfPig;
using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor; using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor;
@@ -35,7 +37,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
command.Transaction = (SqlTransaction)transaction; command.Transaction = (SqlTransaction)transaction;
command.CommandText = """ command.CommandText = """
INSERT INTO Documents2 (Id, [Name], CreationDate) INSERT INTO Documents (Id, [Name], CreationDate)
OUTPUT INSERTED.Id OUTPUT INSERTED.Id
VALUES (@Id, @Name, @CreationDate); VALUES (@Id, @Name, @CreationDate);
"""; """;
@@ -51,20 +53,19 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens); var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens);
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs);
var index = 0; foreach (var (paragraph, index) in paragraphs.WithIndex())
foreach (var (paragraph, embedding) in paragraphs.Zip(embeddings, (p, e) => (p, e.ToArray())))
{ {
command.Parameters.Clear(); command.Parameters.Clear();
command.CommandText = $""" command.CommandText = $"""
INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) INSERT INTO DocumentChunks (DocumentId, [Index], Content, Embedding)
VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embedding.Length}))); VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embeddings[index].Length})));
"""; """;
command.Parameters.AddWithValue("@DocumentId", documentId); command.Parameters.AddWithValue("@DocumentId", documentId);
command.Parameters.AddWithValue("@Index", index++); command.Parameters.AddWithValue("@Index", index);
command.Parameters.AddWithValue("@Content", paragraph); command.Parameters.AddWithValue("@Content", paragraph);
command.Parameters.AddWithValue("@Embedding", JsonSerializer.Serialize(embedding)); command.Parameters.AddWithValue("@Embedding", JsonSerializer.Serialize(embeddings[index]));
await command.ExecuteNonQueryAsync(); await command.ExecuteNonQueryAsync();
} }
@@ -76,83 +77,36 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
public async Task<IEnumerable<Document>> GetDocumentsAsync() public async Task<IEnumerable<Document>> GetDocumentsAsync()
{ {
await sqlConnection.OpenAsync(); var documents = await sqlConnection.QueryAsync<Document>("""
await using var command = sqlConnection.CreateCommand(); SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks WHERE DocumentId = Documents.Id)
FROM Documents
command.CommandText = """ ORDER BY [Name];
SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 """);
WHERE DocumentId = Documents2.Id)
FROM Documents2 ORDER BY [Name];
""";
var documents = new List<Document>();
await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var id = reader.GetGuid(0);
var name = reader.GetString(1);
var creationDate = reader.GetDateTimeOffset(2);
var chunkCount = reader.GetInt32(3);
documents.Add(new(id, name, creationDate, chunkCount));
}
return documents; return documents;
} }
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId) public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
{ {
await sqlConnection.OpenAsync(); var documentChunks = await sqlConnection.QueryAsync<DocumentChunk>("""
await using var command = sqlConnection.CreateCommand();
command.CommandText = """
SELECT Id, [Index], Content SELECT Id, [Index], Content
FROM DocumentChunks2 WHERE DocumentId = @DocumentId FROM DocumentChunks
WHERE DocumentId = @DocumentId
ORDER BY [Index]; ORDER BY [Index];
"""; """, new { documentId });
command.Parameters.AddWithValue("@DocumentId", documentId);
var documentChunks = new List<DocumentChunk>();
await using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var id = reader.GetGuid(0);
var index = reader.GetInt32(1);
var content = reader.GetString(2);
documentChunks.Add(new(id, index, content, null));
}
return documentChunks; return documentChunks;
} }
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
{ {
await sqlConnection.OpenAsync(); var documentChunk = await sqlConnection.QueryFirstOrDefaultAsync<DocumentChunk>("""
await using var command = sqlConnection.CreateCommand(); SELECT Id, [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) AS Embedding
FROM DocumentChunks
WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId;
""", new { documentId, documentChunkId });
command.CommandText = """ return documentChunk;
"SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX))
FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId;
""";
command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId);
command.Parameters.AddWithValue("@DocumentId", documentId);
await using var reader = await command.ExecuteReaderAsync();
if (reader.HasRows && await reader.ReadAsync())
{
var index = reader.GetInt32(0);
var content = reader.GetString(1);
var embedding = JsonSerializer.Deserialize<float[]>(reader.GetString(2))!;
return new(documentChunkId, index, content, embedding);
}
return null;
} }
public async Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null) public async Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null)
@@ -165,7 +119,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
await using var command = sqlConnection.CreateCommand(); await using var command = sqlConnection.CreateCommand();
command.Transaction = transaction as SqlTransaction; command.Transaction = transaction as SqlTransaction;
command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId"; command.CommandText = "DELETE FROM Documents WHERE Id = @DocumentId";
command.Parameters.AddWithValue("@DocumentId", documentId); command.Parameters.AddWithValue("@DocumentId", documentId);
await command.ExecuteNonQueryAsync(); await command.ExecuteNonQueryAsync();
@@ -179,26 +133,32 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
// Perform Vector Search on SQL Database. // Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
await sqlConnection.OpenAsync(); var chunks = await sqlConnection.QueryAsync<string>($"""
await using var command = sqlConnection.CreateCommand();
command.CommandText = $"""
SELECT TOP (@MaxRelevantChunks) Content SELECT TOP (@MaxRelevantChunks) Content
FROM DocumentChunks2 FROM DocumentChunks
ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length}))); ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length})));
"""; """, new { appSettings.MaxRelevantChunks, QuestionEmbedding = JsonSerializer.Serialize(questionEmbedding) });
command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks); //await sqlConnection.OpenAsync();
command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding)); //await using var command = sqlConnection.CreateCommand();
var chunks = new List<string>(); //command.CommandText = $"""
// SELECT TOP (@MaxRelevantChunks) Content
// FROM DocumentChunks
// ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length})));
// """;
await using var reader = await command.ExecuteReaderAsync(); //command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks);
while (await reader.ReadAsync()) //command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding));
{
var content = reader.GetString(0); //var chunks = new List<string>();
chunks.Add(content);
} //await using var reader = await command.ExecuteReaderAsync();
//while (await reader.ReadAsync())
//{
// var content = reader.GetString(0);
// chunks.Add(content);
//}
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
return new Response(reformulatedQuestion, answer); return new Response(reformulatedQuestion, answer);
@@ -8,6 +8,7 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Dapper" Version="2.1.35" />
<PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="8.1.3" /> <PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="8.1.3" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.8" /> <PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.8" />
<PackageReference Include="Microsoft.Data.SqlClient" Version="5.2.2" /> <PackageReference Include="Microsoft.Data.SqlClient" Version="5.2.2" />