diff --git a/SqlDatabaseVectorSearch/DataAccessLayer/ApplicationDbContext.cs b/SqlDatabaseVectorSearch/DataAccessLayer/ApplicationDbContext.cs new file mode 100644 index 0000000..780b9e7 --- /dev/null +++ b/SqlDatabaseVectorSearch/DataAccessLayer/ApplicationDbContext.cs @@ -0,0 +1,51 @@ +using EntityFramework.Exceptions.SqlServer; +using Microsoft.EntityFrameworkCore; +using SqlDatabaseVectorSearch.DataAccessLayer.Entities; + +namespace SqlDatabaseVectorSearch.DataAccessLayer; + +public class ApplicationDbContext(DbContextOptions options) : DbContext(options) +{ + public virtual DbSet Documents { get; set; } + + public virtual DbSet DocumentChunks { get; set; } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + base.OnConfiguring(optionsBuilder); + + optionsBuilder.UseExceptionProcessor(); + //optionsBuilder.EnableSensitiveDataLogging(); + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity(entity => + { + entity.ToTable("Documents"); + entity.HasKey(e => e.Id); + + entity.Property(e => e.Id).ValueGeneratedOnAdd(); + entity.Property(e => e.Name) + .IsRequired() + .HasMaxLength(255); + }); + + modelBuilder.Entity(entity => + { + entity.ToTable("DocumentChunks"); + entity.HasKey(e => e.Id); + + entity.Property(e => e.Id).ValueGeneratedOnAdd(); + entity.Property(e => e.Content).IsRequired(); + entity.Property(e => e.Embedding) + .IsRequired() + .HasColumnType("vector(1536)"); + + entity.HasOne(d => d.Document).WithMany(p => p.Chunks) + .HasForeignKey(d => d.DocumentId) + .OnDelete(DeleteBehavior.Cascade) + .HasConstraintName("FK_DocumentChunks_Documents"); + }); + } +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/DataAccessLayer/Entities/Document.cs b/SqlDatabaseVectorSearch/DataAccessLayer/Entities/Document.cs new file mode 100644 index 0000000..d90cf56 --- /dev/null +++ b/SqlDatabaseVectorSearch/DataAccessLayer/Entities/Document.cs @@ -0,0 +1,12 @@ +namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities; + +public class Document +{ + public Guid Id { get; set; } + + public required string Name { get; set; } + + public DateTimeOffset CreationDate { get; set; } + + public virtual ICollection Chunks { get; set; } = []; +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/DataAccessLayer/Entities/DocumentChunk.cs b/SqlDatabaseVectorSearch/DataAccessLayer/Entities/DocumentChunk.cs new file mode 100644 index 0000000..0d1886c --- /dev/null +++ b/SqlDatabaseVectorSearch/DataAccessLayer/Entities/DocumentChunk.cs @@ -0,0 +1,16 @@ +namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities; + +public class DocumentChunk +{ + public Guid Id { get; set; } + + public Guid DocumentId { get; set; } + + public int Index { get; set; } + + public required string Content { get; set; } + + public required float[] Embedding { get; set; } + + public virtual Document Document { get; set; } = null!; +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/DocumentChunk.cs b/SqlDatabaseVectorSearch/Models/DocumentChunk.cs index 15bfab7..3753352 100644 --- a/SqlDatabaseVectorSearch/Models/DocumentChunk.cs +++ b/SqlDatabaseVectorSearch/Models/DocumentChunk.cs @@ -1,14 +1,3 @@ -using System.Text.Json; +namespace SqlDatabaseVectorSearch.Models; -namespace SqlDatabaseVectorSearch.Models; - -public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding) -{ - public DocumentChunk(Guid Id, int Index, string Content) : this(Id, Index, Content, (float[]?)null) - { - } - - public DocumentChunk(Guid Id, int Index, string Content, string Embedding) : this(Id, Index, Content, JsonSerializer.Deserialize(Embedding)) - { - } -} +public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding = null); diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index f613a87..a798f73 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -1,8 +1,9 @@ using Microsoft.AspNetCore.Http.HttpResults; -using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; using Microsoft.OpenApi.Models; using Microsoft.SemanticKernel; using MinimalHelpers.OpenApi; +using SqlDatabaseVectorSearch.DataAccessLayer; using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Settings; @@ -18,10 +19,12 @@ var appSettings = builder.Services.ConfigureAndGet(builder.Configur builder.Services.AddSingleton(TimeProvider.System); -builder.Services.AddScoped(_ => +builder.Services.AddSqlServer(builder.Configuration.GetConnectionString("SqlConnection"), options => { - var sqlConnection = new SqlConnection(builder.Configuration.GetConnectionString("SqlConnection")); - return sqlConnection; + options.UseVectorSearch(); +}, options => +{ + options.UseQueryTrackingBehavior(QueryTrackingBehavior.NoTracking); }); builder.Services.AddMemoryCache(); diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index f21249d..c725568 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -1,21 +1,20 @@ using System.Data; -using System.Data.Common; using System.Text; -using System.Text.Json; -using Dapper; -using Microsoft.Data.SqlClient; +using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Text; +using SqlDatabaseVectorSearch.DataAccessLayer; using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Settings; using TinyHelpers.Extensions; using UglyToad.PdfPig; using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor; +using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities; namespace SqlDatabaseVectorSearch.Services; -public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions) +public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions) { private readonly AppSettings appSettings = appSettingsOptions.Value; @@ -24,21 +23,16 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene // Extract the contents of the file (currently, only PDF files are supported). var content = await GetContentAsync(stream); - await sqlConnection.OpenAsync(); - await using var transaction = await sqlConnection.BeginTransactionAsync(); + await dbContext.Database.BeginTransactionAsync(); if (documentId.HasValue) { // If the user is importing a document that already exists, delete the previous one. - await DeleteDocumentAsync(documentId.Value, transaction); + await DeleteDocumentAsync(documentId.Value); } - documentId = await sqlConnection.ExecuteScalarAsync($""" - INSERT INTO Documents (Id, [Name], CreationDate) - OUTPUT INSERTED.Id - VALUES (@Id, @Name, @CreationDate); - """, new { Id = documentId.GetValueOrDefault(Guid.NewGuid()), Name = name, CreationDate = timeProvider.GetUtcNow() }, - transaction); + var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; + dbContext.Documents.Add(document); // Split the content into chunks and generate the embeddings for each one. var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens); @@ -47,54 +41,45 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene // Save the document chunks and the corresponding embedding in the database. foreach (var (paragraph, index) in paragraphs.WithIndex()) { - await sqlConnection.ExecuteAsync($""" - INSERT INTO DocumentChunks (DocumentId, [Index], Content, Embedding) - VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embeddings[index].Length}))); - """, new { DocumentId = documentId, Index = index, Content = paragraph, Embedding = JsonSerializer.Serialize(embeddings[index]) }, - transaction); + var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() }; + dbContext.DocumentChunks.Add(documentChunk); } - await transaction.CommitAsync(); + await dbContext.SaveChangesAsync(); + await dbContext.Database.CommitTransactionAsync(); - return documentId.Value; + return document.Id; } public async Task> GetDocumentsAsync() { - var documents = await sqlConnection.QueryAsync(""" - SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks WHERE DocumentId = Documents.Id) - FROM Documents - ORDER BY [Name]; - """); + var documents = await dbContext.Documents.OrderBy(d => d.Name) + .Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count)) + .ToListAsync(); return documents; } public async Task> GetDocumentChunksAsync(Guid documentId) { - var documentChunks = await sqlConnection.QueryAsync(""" - SELECT Id, [Index], Content - FROM DocumentChunks - WHERE DocumentId = @DocumentId - ORDER BY [Index]; - """, new { documentId }); + var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index) + .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null)) + .ToListAsync(); return documentChunks; } public async Task GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) { - var documentChunk = await sqlConnection.QueryFirstOrDefaultAsync(""" - SELECT Id, [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) AS Embedding - FROM DocumentChunks - WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId; - """, new { documentId, documentChunkId }); + var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId) + .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding)) + .FirstOrDefaultAsync(); return documentChunk; } - public Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null) - => sqlConnection.ExecuteAsync("DELETE FROM Documents WHERE Id = @DocumentId", new { DocumentId = documentId }, transaction); + public Task DeleteDocumentAsync(Guid documentId) + => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); public async Task AskQuestionAsync(Question question, bool reformulate = true) { @@ -104,11 +89,11 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene // Perform Vector Search on SQL Database. var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); - var chunks = await sqlConnection.QueryAsync($""" - SELECT TOP (@MaxRelevantChunks) Content - FROM DocumentChunks - ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length}))); - """, new { appSettings.MaxRelevantChunks, QuestionEmbedding = JsonSerializer.Serialize(questionEmbedding) }); + var chunks = await dbContext.DocumentChunks + .OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) + .Select(c => c.Content) + .Take(appSettings.MaxRelevantChunks) + .ToListAsync(); var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); return new Response(reformulatedQuestion, answer); diff --git a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj index 5e61966..9f22105 100644 --- a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj +++ b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj @@ -8,9 +8,10 @@ - - - + + + +