From 3e952514852c659400417bc40688b076c649c128 Mon Sep 17 00:00:00 2001 From: Marco Minerva Date: Mon, 30 Sep 2024 17:08:28 +0200 Subject: [PATCH] Refactor to use the native VECTOR type --- Scripts.sql | 20 ++- .../DataAccessLayer/ApplicationDbContext.cs | 52 ------- .../DataAccessLayer/Entities/Document.cs | 12 -- .../DataAccessLayer/Entities/DocumentChunk.cs | 16 -- SqlDatabaseVectorSearch/Program.cs | 10 +- .../Services/VectorSearchService.cs | 142 ++++++++++++++---- .../SqlDatabaseVectorSearch.csproj | 3 +- 7 files changed, 129 insertions(+), 126 deletions(-) delete mode 100644 SqlDatabaseVectorSearch/DataAccessLayer/ApplicationDbContext.cs delete mode 100644 SqlDatabaseVectorSearch/DataAccessLayer/Entities/Document.cs delete mode 100644 SqlDatabaseVectorSearch/DataAccessLayer/Entities/DocumentChunk.cs diff --git a/Scripts.sql b/Scripts.sql index bb7843f..d62c420 100644 --- a/Scripts.sql +++ b/Scripts.sql @@ -1,26 +1,32 @@ -CREATE TABLE [dbo].[DocumentChunks]( +CREATE TABLE [dbo].[DocumentChunks2]( [Id] [uniqueidentifier] NOT NULL, [DocumentId] [uniqueidentifier] NOT NULL, [Index] [int] NOT NULL, [Content] [nvarchar](max) NOT NULL, - [Embedding] [varbinary](8000) NOT NULL, - CONSTRAINT [PK_DocumentChunks] PRIMARY KEY CLUSTERED + [Embedding] [vector](1536) NOT NULL, + CONSTRAINT [PK_DocumentChunks2] PRIMARY KEY CLUSTERED ( [Id] ASC )) GO -CREATE TABLE [dbo].[Documents]( +CREATE TABLE [dbo].[Documents2]( [Id] [uniqueidentifier] NOT NULL, [Name] [nvarchar](255) NOT NULL, [CreationDate] [datetimeoffset](7) NOT NULL, - CONSTRAINT [PK_Documents] PRIMARY KEY CLUSTERED + CONSTRAINT [PK_Documents2] PRIMARY KEY CLUSTERED ( [Id] ASC )) GO -ALTER TABLE [dbo].[DocumentChunks] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks_Documents] FOREIGN KEY([DocumentId]) -REFERENCES [dbo].[Documents] ([Id]) +ALTER TABLE [dbo].[DocumentChunks2] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks2_Documents2] FOREIGN KEY([DocumentId]) +REFERENCES [dbo].[Documents2] ([Id]) ON DELETE CASCADE +GO + +ALTER TABLE [dbo].[Documents2] ADD CONSTRAINT [DF_Documents2_Id] DEFAULT (newsequentialid()) FOR [Id] +GO + +ALTER TABLE [dbo].[DocumentChunks2] ADD CONSTRAINT [DF_DocumentChunks2_Id] DEFAULT (newsequentialid()) FOR [Id] GO \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/DataAccessLayer/ApplicationDbContext.cs b/SqlDatabaseVectorSearch/DataAccessLayer/ApplicationDbContext.cs deleted file mode 100644 index f4d88ed..0000000 --- a/SqlDatabaseVectorSearch/DataAccessLayer/ApplicationDbContext.cs +++ /dev/null @@ -1,52 +0,0 @@ -using EntityFramework.Exceptions.SqlServer; -using Microsoft.EntityFrameworkCore; -using SqlDatabaseVectorSearch.DataAccessLayer.Entities; - -namespace SqlDatabaseVectorSearch.DataAccessLayer; - -public class ApplicationDbContext(DbContextOptions options) : DbContext(options) -{ - public virtual DbSet Documents { get; set; } - - public virtual DbSet DocumentChunks { get; set; } - - protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) - { - base.OnConfiguring(optionsBuilder); - - optionsBuilder.UseExceptionProcessor(); - //optionsBuilder.EnableSensitiveDataLogging(); - } - - protected override void OnModelCreating(ModelBuilder modelBuilder) - { - modelBuilder.Entity(entity => - { - entity.ToTable("Documents"); - entity.HasKey(e => e.Id); - - entity.Property(e => e.Id).ValueGeneratedOnAdd(); - entity.Property(e => e.Name) - .IsRequired() - .HasMaxLength(255); - }); - - modelBuilder.Entity(entity => - { - entity.ToTable("DocumentChunks"); - entity.HasKey(e => e.Id); - - entity.Property(e => e.Id).ValueGeneratedOnAdd(); - entity.Property(e => e.Content).IsRequired(); - entity.Property(e => e.Embedding) - .IsRequired() - .HasMaxLength(8000) - .IsVector(); - - entity.HasOne(d => d.Document).WithMany(p => p.Chunks) - .HasForeignKey(d => d.DocumentId) - .OnDelete(DeleteBehavior.Cascade) - .HasConstraintName("FK_DocumentChunks_Documents"); - }); - } -} diff --git a/SqlDatabaseVectorSearch/DataAccessLayer/Entities/Document.cs b/SqlDatabaseVectorSearch/DataAccessLayer/Entities/Document.cs deleted file mode 100644 index 17f4b1a..0000000 --- a/SqlDatabaseVectorSearch/DataAccessLayer/Entities/Document.cs +++ /dev/null @@ -1,12 +0,0 @@ -namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities; - -public class Document -{ - public Guid Id { get; set; } - - public required string Name { get; set; } - - public DateTimeOffset CreationDate { get; set; } - - public virtual ICollection Chunks { get; set; } = []; -} diff --git a/SqlDatabaseVectorSearch/DataAccessLayer/Entities/DocumentChunk.cs b/SqlDatabaseVectorSearch/DataAccessLayer/Entities/DocumentChunk.cs deleted file mode 100644 index ea21ebc..0000000 --- a/SqlDatabaseVectorSearch/DataAccessLayer/Entities/DocumentChunk.cs +++ /dev/null @@ -1,16 +0,0 @@ -namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities; - -public class DocumentChunk -{ - public Guid Id { get; set; } - - public Guid DocumentId { get; set; } - - public int Index { get; set; } - - public required string Content { get; set; } - - public required float[] Embedding { get; set; } - - public virtual Document Document { get; set; } = null!; -} diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index df29d6f..8832861 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -1,9 +1,9 @@ using Microsoft.AspNetCore.Http.HttpResults; +using Microsoft.Data.SqlClient; using Microsoft.EntityFrameworkCore; using Microsoft.OpenApi.Models; using Microsoft.SemanticKernel; using MinimalHelpers.OpenApi; -using SqlDatabaseVectorSearch.DataAccessLayer; using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Settings; @@ -19,12 +19,10 @@ var appSettings = builder.Services.ConfigureAndGet(builder.Configur builder.Services.AddSingleton(TimeProvider.System); -builder.Services.AddSqlServer(builder.Configuration.GetConnectionString("SqlConnection"), options => +builder.Services.AddScoped(_ => { - options.UseVectorSearch(); -}, options => -{ - options.UseQueryTrackingBehavior(QueryTrackingBehavior.NoTracking); + var sqlConnection = new SqlConnection(builder.Configuration.GetConnectionString("SqlConnection")); + return sqlConnection; }); builder.Services.AddMemoryCache(); diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index d7e10ce..013a60b 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -1,18 +1,19 @@ -using System.Text; -using Microsoft.EntityFrameworkCore; +using System.Data; +using System.Data.Common; +using System.Text; +using System.Text.Json; +using Microsoft.Data.SqlClient; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Text; -using SqlDatabaseVectorSearch.DataAccessLayer; using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Settings; using UglyToad.PdfPig; using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor; -using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities; namespace SqlDatabaseVectorSearch.Services; -public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions) +public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions) { private readonly AppSettings appSettings = appSettingsOptions.Value; @@ -21,16 +22,25 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG // Extract the contents of the file (currently, only PDF files are supported). var content = await GetContentAsync(stream); - await dbContext.Database.BeginTransactionAsync(); + await sqlConnection.OpenAsync(); + await using var transaction = await sqlConnection.BeginTransactionAsync(); if (documentId.HasValue) { // If the user is importing a document that already exists, delete the previous one. - await DeleteDocumentAsync(documentId.Value); + await DeleteDocumentAsync(documentId.Value, transaction); } - var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; - dbContext.Documents.Add(document); + await using var command = sqlConnection.CreateCommand(); + command.Transaction = (SqlTransaction)transaction; + + command.CommandText = "INSERT INTO Documents2 (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate)"; + command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid())); + command.Parameters.AddWithValue("@Name", name); + command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow()); + + var insertedId = await command.ExecuteScalarAsync(); + documentId = (Guid)insertedId!; // Split the content into chunks and generate the embeddings for each one. var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens); @@ -39,45 +49,105 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG var index = 0; foreach (var (paragraph, embedding) in paragraphs.Zip(embeddings, (p, e) => (p, e.ToArray()))) { - var documentChunk = new Entities.DocumentChunk { Document = document, Index = index++, Content = paragraph, Embedding = embedding }; - dbContext.DocumentChunks.Add(documentChunk); + command.Parameters.Clear(); + + command.CommandText = "INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR(1536)))"; + command.Parameters.AddWithValue("@DocumentId", documentId); + command.Parameters.AddWithValue("@Index", index++); + command.Parameters.AddWithValue("@Content", paragraph); + command.Parameters.AddWithValue("@Embedding", JsonSerializer.Serialize(embedding)); + + await command.ExecuteNonQueryAsync(); } - await dbContext.SaveChangesAsync(); - await dbContext.Database.CommitTransactionAsync(); + await transaction.CommitAsync(); - return document.Id; + return documentId.Value; } public async Task> GetDocumentsAsync() { - var documents = await dbContext.Documents.OrderBy(d => d.Name) - .Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count)) - .ToListAsync(); + await sqlConnection.OpenAsync(); + await using var command = sqlConnection.CreateCommand(); + + command.CommandText = "SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 WHERE DocumentId = Documents2.Id) FROM Documents2 ORDER BY [Name]"; + + var documents = new List(); + + using var reader = await command.ExecuteReaderAsync(); + while (await reader.ReadAsync()) + { + var id = reader.GetGuid(0); + var name = reader.GetString(1); + var creationDate = reader.GetDateTimeOffset(2); + var chunkCount = reader.GetInt32(3); + + documents.Add(new(id, name, creationDate, chunkCount)); + } return documents; } public async Task> GetDocumentChunksAsync(Guid documentId) { - var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index) - .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null)) - .ToListAsync(); + await sqlConnection.OpenAsync(); + await using var command = sqlConnection.CreateCommand(); + + command.CommandText = "SELECT Id, [Index], Content FROM DocumentChunks2 WHERE DocumentId = @DocumentId ORDER BY [Index]"; + command.Parameters.AddWithValue("@DocumentId", documentId); + + var documentChunks = new List(); + + using var reader = await command.ExecuteReaderAsync(); + while (await reader.ReadAsync()) + { + var id = reader.GetGuid(0); + var index = reader.GetInt32(1); + var content = reader.GetString(2); + + documentChunks.Add(new(id, index, content, null)); + } return documentChunks; } public async Task GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) { - var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId) - .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding)) - .FirstOrDefaultAsync(); + await sqlConnection.OpenAsync(); + await using var command = sqlConnection.CreateCommand(); - return documentChunk; + command.CommandText = "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId"; + command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId); + command.Parameters.AddWithValue("@DocumentId", documentId); + + using var reader = await command.ExecuteReaderAsync(); + if (reader.HasRows && await reader.ReadAsync()) + { + var index = reader.GetInt32(0); + var content = reader.GetString(1); + var embedding = JsonSerializer.Deserialize(reader.GetString(2))!; + + return new(documentChunkId, index, content, embedding); + } + + return null; } - public Task DeleteDocumentAsync(Guid documentId) - => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); + public async Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null) + { + if (sqlConnection.State == ConnectionState.Closed) + { + await sqlConnection.OpenAsync(); + } + + using var command = sqlConnection.CreateCommand(); + command.Transaction = transaction as SqlTransaction; + + command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId"; + command.Parameters.AddWithValue("@DocumentId", documentId); + + await command.ExecuteNonQueryAsync(); + } public async Task AskQuestionAsync(Question question, bool reformulate = true) { @@ -87,11 +157,21 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG // Perform Vector Search on SQL Database. var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); - var chunks = await dbContext.DocumentChunks - .OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) - .Select(c => c.Content) - .Take(appSettings.MaxRelevantChunks) - .ToListAsync(); + await sqlConnection.OpenAsync(); + await using var command = sqlConnection.CreateCommand(); + + command.CommandText = "SELECT TOP (@MaxRelevantChunks) Content FROM DocumentChunks2 ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR(1536)))"; + command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks); + command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding)); + + var chunks = new List(); + + using var reader = await command.ExecuteReaderAsync(); + while (await reader.ReadAsync()) + { + var content = reader.GetString(0); + chunks.Add(content); + } var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); return new Response(reformulatedQuestion, answer); diff --git a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj index a5caf99..3e07174 100644 --- a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj +++ b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj @@ -8,10 +8,9 @@ - - +