Refactor to use EF Core for database operations

Refactored the codebase to replace raw SQL connections and Dapper with Entity Framework Core (EF Core). Modified `Program.cs` to configure EF Core services. Refactored `VectorSearchService` to use EF Core for all database operations. Updated project dependencies to remove Dapper and `Microsoft.Data.SqlClient`, and add EF Core packages. Added `ApplicationDbContext` for EF Core context and new `Document` and `DocumentChunk` classes for entity models.
This commit is contained in:
Marco Minerva
2024-10-31 15:16:38 +01:00
parent 1d8fa7db12
commit 0435f042f1
7 changed files with 121 additions and 64 deletions
@@ -1,21 +1,20 @@
using System.Data;
using System.Data.Common;
using System.Text;
using System.Text.Json;
using Dapper;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.DataAccessLayer;
using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Settings;
using TinyHelpers.Extensions;
using UglyToad.PdfPig;
using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor;
using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities;
namespace SqlDatabaseVectorSearch.Services;
public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions)
public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions)
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
@@ -24,21 +23,16 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
// Extract the contents of the file (currently, only PDF files are supported).
var content = await GetContentAsync(stream);
await sqlConnection.OpenAsync();
await using var transaction = await sqlConnection.BeginTransactionAsync();
await dbContext.Database.BeginTransactionAsync();
if (documentId.HasValue)
{
// If the user is importing a document that already exists, delete the previous one.
await DeleteDocumentAsync(documentId.Value, transaction);
await DeleteDocumentAsync(documentId.Value);
}
documentId = await sqlConnection.ExecuteScalarAsync<Guid>($"""
INSERT INTO Documents (Id, [Name], CreationDate)
OUTPUT INSERTED.Id
VALUES (@Id, @Name, @CreationDate);
""", new { Id = documentId.GetValueOrDefault(Guid.NewGuid()), Name = name, CreationDate = timeProvider.GetUtcNow() },
transaction);
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
dbContext.Documents.Add(document);
// Split the content into chunks and generate the embeddings for each one.
var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens);
@@ -47,54 +41,45 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
// Save the document chunks and the corresponding embedding in the database.
foreach (var (paragraph, index) in paragraphs.WithIndex())
{
await sqlConnection.ExecuteAsync($"""
INSERT INTO DocumentChunks (DocumentId, [Index], Content, Embedding)
VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embeddings[index].Length})));
""", new { DocumentId = documentId, Index = index, Content = paragraph, Embedding = JsonSerializer.Serialize(embeddings[index]) },
transaction);
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() };
dbContext.DocumentChunks.Add(documentChunk);
}
await transaction.CommitAsync();
await dbContext.SaveChangesAsync();
await dbContext.Database.CommitTransactionAsync();
return documentId.Value;
return document.Id;
}
/// <summary>
/// Lists the stored documents ordered by name, each with the number of chunks that belong to it.
/// </summary>
/// <returns>One <c>Document</c> model per row: Id, Name, CreationDate and chunk count.</returns>
// NOTE(review): this span is a rendered diff without +/- markers, so two implementations appear
// back to back. Per the commit message (Dapper replaced by EF Core), the sqlConnection query is
// the removed code and the dbContext query is its replacement.
public async Task<IEnumerable<Document>> GetDocumentsAsync()
{
// Previous implementation (Dapper): raw SQL; ChunkCount comes from a correlated COUNT(*) subquery.
var documents = await sqlConnection.QueryAsync<Document>("""
SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks WHERE DocumentId = Documents.Id)
FROM Documents
ORDER BY [Name];
""");
// New implementation (EF Core): same ordering and projection; d.Chunks.Count supplies the chunk count.
var documents = await dbContext.Documents.OrderBy(d => d.Name)
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count))
.ToListAsync();
return documents;
}
/// <summary>
/// Lists the chunks of a single document, ordered by their index, without the embedding.
/// </summary>
/// <param name="documentId">Identifier of the document whose chunks are requested.</param>
/// <returns>One <c>DocumentChunk</c> model per chunk: Id, Index and Content.</returns>
// NOTE(review): this span is a rendered diff without +/- markers. Per the commit message, the
// sqlConnection query is the removed Dapper code and the dbContext query is the EF Core replacement.
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
{
// Previous implementation (Dapper): selects only Id, [Index] and Content.
var documentChunks = await sqlConnection.QueryAsync<DocumentChunk>("""
SELECT Id, [Index], Content
FROM DocumentChunks
WHERE DocumentId = @DocumentId
ORDER BY [Index];
""", new { documentId });
// New implementation (EF Core): the fourth constructor argument (the embedding) is passed as null,
// so the embedding column is not part of the projection.
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null))
.ToListAsync();
return documentChunks;
}
/// <summary>
/// Loads a single chunk, including its embedding, identified by both the chunk id and the id of
/// the document it belongs to.
/// </summary>
/// <param name="documentId">Identifier of the owning document; must match the chunk's DocumentId.</param>
/// <param name="documentChunkId">Identifier of the chunk to load.</param>
/// <returns>The matching chunk, or <c>null</c> when no row satisfies both conditions.</returns>
// NOTE(review): this span is a rendered diff without +/- markers. Per the commit message, the
// sqlConnection query is the removed Dapper code and the dbContext query is the EF Core replacement.
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
{
// Previous implementation (Dapper): the embedding is cast to NVARCHAR(MAX) in SQL before being returned.
var documentChunk = await sqlConnection.QueryFirstOrDefaultAsync<DocumentChunk>("""
SELECT Id, [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) AS Embedding
FROM DocumentChunks
WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId;
""", new { documentId, documentChunkId });
// New implementation (EF Core): c.Embedding is handed to the model constructor directly, with no cast.
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding))
.FirstOrDefaultAsync();
return documentChunk;
}
// NOTE(review): this span is a rendered diff without +/- markers; both versions of the method are
// shown. Per the commit message, the sqlConnection overload is the removed Dapper code.
// Previous implementation (Dapper): deletes the Documents row with the given id, optionally
// enlisting in the supplied transaction.
public Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null)
=> sqlConnection.ExecuteAsync("DELETE FROM Documents WHERE Id = @DocumentId", new { DocumentId = documentId }, transaction);
/// <summary>
/// Deletes the document with the given id by running ExecuteDeleteAsync over the filtered Documents set.
/// </summary>
/// <param name="documentId">Identifier of the document to delete.</param>
// The optional transaction parameter of the previous signature is gone; callers that passed it
// must be updated (the import code in this same diff now calls it with the id only).
// NOTE(review): whether the related DocumentChunks rows are removed too depends on the
// relationship/cascade configuration, which is not visible here — confirm.
public Task DeleteDocumentAsync(Guid documentId)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
{
@@ -104,11 +89,11 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
// Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
var chunks = await sqlConnection.QueryAsync<string>($"""
SELECT TOP (@MaxRelevantChunks) Content
FROM DocumentChunks
ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length})));
""", new { appSettings.MaxRelevantChunks, QuestionEmbedding = JsonSerializer.Serialize(questionEmbedding) });
var chunks = await dbContext.DocumentChunks
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
.Select(c => c.Content)
.Take(appSettings.MaxRelevantChunks)
.ToListAsync();
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
return new Response(reformulatedQuestion, answer);