diff --git a/Scripts.sql b/Scripts.sql index d62c420..a2e1e93 100644 --- a/Scripts.sql +++ b/Scripts.sql @@ -1,4 +1,14 @@ -CREATE TABLE [dbo].[DocumentChunks2]( +CREATE TABLE [dbo].[Documents]( + [Id] [uniqueidentifier] NOT NULL, + [Name] [nvarchar](255) NOT NULL, + [CreationDate] [datetimeoffset](7) NOT NULL, + CONSTRAINT [PK_Documents2] PRIMARY KEY CLUSTERED +( + [Id] ASC +)) +GO + +CREATE TABLE [dbo].[DocumentChunks]( [Id] [uniqueidentifier] NOT NULL, [DocumentId] [uniqueidentifier] NOT NULL, [Index] [int] NOT NULL, @@ -10,23 +20,13 @@ CREATE TABLE [dbo].[DocumentChunks2]( )) GO -CREATE TABLE [dbo].[Documents2]( - [Id] [uniqueidentifier] NOT NULL, - [Name] [nvarchar](255) NOT NULL, - [CreationDate] [datetimeoffset](7) NOT NULL, - CONSTRAINT [PK_Documents2] PRIMARY KEY CLUSTERED -( - [Id] ASC -)) -GO - -ALTER TABLE [dbo].[DocumentChunks2] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks2_Documents2] FOREIGN KEY([DocumentId]) -REFERENCES [dbo].[Documents2] ([Id]) +ALTER TABLE [dbo].[DocumentChunks] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks_Documents] FOREIGN KEY([DocumentId]) +REFERENCES [dbo].[Documents] ([Id]) ON DELETE CASCADE GO -ALTER TABLE [dbo].[Documents2] ADD CONSTRAINT [DF_Documents2_Id] DEFAULT (newsequentialid()) FOR [Id] +ALTER TABLE [dbo].[Documents] ADD CONSTRAINT [DF_Documents_Id] DEFAULT (newsequentialid()) FOR [Id] GO -ALTER TABLE [dbo].[DocumentChunks2] ADD CONSTRAINT [DF_DocumentChunks2_Id] DEFAULT (newsequentialid()) FOR [Id] +ALTER TABLE [dbo].[DocumentChunks] ADD CONSTRAINT [DF_DocumentChunks_Id] DEFAULT (newsequentialid()) FOR [Id] GO \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/DocumentChunk.cs b/SqlDatabaseVectorSearch/Models/DocumentChunk.cs index 3753352..15bfab7 100644 --- a/SqlDatabaseVectorSearch/Models/DocumentChunk.cs +++ b/SqlDatabaseVectorSearch/Models/DocumentChunk.cs @@ -1,3 +1,14 @@ -namespace SqlDatabaseVectorSearch.Models; +using System.Text.Json; -public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding = null); +namespace SqlDatabaseVectorSearch.Models; + +public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding) +{ + public DocumentChunk(Guid Id, int Index, string Content) : this(Id, Index, Content, (float[]?)null) + { + } + + public DocumentChunk(Guid Id, int Index, string Content, string Embedding) : this(Id, Index, Content, JsonSerializer.Deserialize(Embedding)) + { + } +} diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 207689d..f1b57e0 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -2,12 +2,14 @@ using System.Data.Common; using System.Text; using System.Text.Json; +using Dapper; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Text; using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Settings; +using TinyHelpers.Extensions; using UglyToad.PdfPig; using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor; @@ -35,7 +37,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene command.Transaction = (SqlTransaction)transaction; command.CommandText = """ - INSERT INTO Documents2 (Id, [Name], CreationDate) + INSERT INTO Documents (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate); """; @@ -51,20 +53,19 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens); var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); - var index = 0; - foreach (var (paragraph, embedding) in paragraphs.Zip(embeddings, (p, e) => (p, e.ToArray()))) + foreach (var (paragraph, index) in paragraphs.WithIndex()) { command.Parameters.Clear(); command.CommandText = $""" - INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) - VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embedding.Length}))); + INSERT INTO DocumentChunks (DocumentId, [Index], Content, Embedding) + VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embeddings[index].Length}))); """; command.Parameters.AddWithValue("@DocumentId", documentId); - command.Parameters.AddWithValue("@Index", index++); + command.Parameters.AddWithValue("@Index", index); command.Parameters.AddWithValue("@Content", paragraph); - command.Parameters.AddWithValue("@Embedding", JsonSerializer.Serialize(embedding)); + command.Parameters.AddWithValue("@Embedding", JsonSerializer.Serialize(embeddings[index])); await command.ExecuteNonQueryAsync(); } @@ -76,83 +77,36 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene public async Task> GetDocumentsAsync() { - await sqlConnection.OpenAsync(); - await using var command = sqlConnection.CreateCommand(); - - command.CommandText = """ - SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 - WHERE DocumentId = Documents2.Id) - FROM Documents2 ORDER BY [Name]; - """; - - var documents = new List(); - - await using var reader = await command.ExecuteReaderAsync(); - while (await reader.ReadAsync()) - { - var id = reader.GetGuid(0); - var name = reader.GetString(1); - var creationDate = reader.GetDateTimeOffset(2); - var chunkCount = reader.GetInt32(3); - - documents.Add(new(id, name, creationDate, chunkCount)); - } + var documents = await sqlConnection.QueryAsync(""" + SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks WHERE DocumentId = Documents.Id) + FROM Documents + ORDER BY [Name]; + """); return documents; } public async Task> GetDocumentChunksAsync(Guid documentId) { - await sqlConnection.OpenAsync(); - await using var command = sqlConnection.CreateCommand(); - - command.CommandText = """ + var documentChunks = await sqlConnection.QueryAsync(""" SELECT Id, [Index], Content - FROM DocumentChunks2 WHERE DocumentId = @DocumentId + FROM DocumentChunks + WHERE DocumentId = @DocumentId ORDER BY [Index]; - """; - - command.Parameters.AddWithValue("@DocumentId", documentId); - - var documentChunks = new List(); - - await using var reader = await command.ExecuteReaderAsync(); - while (await reader.ReadAsync()) - { - var id = reader.GetGuid(0); - var index = reader.GetInt32(1); - var content = reader.GetString(2); - - documentChunks.Add(new(id, index, content, null)); - } + """, new { documentId }); return documentChunks; } public async Task GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) { - await sqlConnection.OpenAsync(); - await using var command = sqlConnection.CreateCommand(); + var documentChunk = await sqlConnection.QueryFirstOrDefaultAsync(""" + SELECT Id, [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) AS Embedding + FROM DocumentChunks + WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId; + """, new { documentId, documentChunkId }); - command.CommandText = """ - "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) - FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId; - """; - - command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId); - command.Parameters.AddWithValue("@DocumentId", documentId); - - await using var reader = await command.ExecuteReaderAsync(); - if (reader.HasRows && await reader.ReadAsync()) - { - var index = reader.GetInt32(0); - var content = reader.GetString(1); - var embedding = JsonSerializer.Deserialize(reader.GetString(2))!; - - return new(documentChunkId, index, content, embedding); - } - - return null; + return documentChunk; } public async Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null) @@ -165,7 +119,7 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene await using var command = sqlConnection.CreateCommand(); command.Transaction = transaction as SqlTransaction; - command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId"; + command.CommandText = "DELETE FROM Documents WHERE Id = @DocumentId"; command.Parameters.AddWithValue("@DocumentId", documentId); await command.ExecuteNonQueryAsync(); @@ -179,26 +133,32 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene // Perform Vector Search on SQL Database. var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); - await sqlConnection.OpenAsync(); - await using var command = sqlConnection.CreateCommand(); - - command.CommandText = $""" + var chunks = await sqlConnection.QueryAsync($""" SELECT TOP (@MaxRelevantChunks) Content - FROM DocumentChunks2 + FROM DocumentChunks ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length}))); - """; + """, new { appSettings.MaxRelevantChunks, QuestionEmbedding = JsonSerializer.Serialize(questionEmbedding) }); - command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks); - command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding)); + //await sqlConnection.OpenAsync(); + //await using var command = sqlConnection.CreateCommand(); - var chunks = new List(); + //command.CommandText = $""" + // SELECT TOP (@MaxRelevantChunks) Content + // FROM DocumentChunks + // ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length}))); + // """; - await using var reader = await command.ExecuteReaderAsync(); - while (await reader.ReadAsync()) - { - var content = reader.GetString(0); - chunks.Add(content); - } + //command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks); + //command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding)); + + //var chunks = new List(); + + //await using var reader = await command.ExecuteReaderAsync(); + //while (await reader.ReadAsync()) + //{ + // var content = reader.GetString(0); + // chunks.Add(content); + //} var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); return new Response(reformulatedQuestion, answer); diff --git a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj index 3e07174..16d541a 100644 --- a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj +++ b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj @@ -8,6 +8,7 @@ +