mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Refactor to use EF Core for database operations
Refactored the codebase to replace raw SQL connections and Dapper with Entity Framework Core (EF Core). Modified `Program.cs` to configure EF Core services. Refactored `VectorSearchService` to use EF Core for all database operations. Updated project dependencies to remove Dapper and `Microsoft.Data.SqlClient`, and add EF Core packages. Added `ApplicationDbContext` for EF Core context and new `Document` and `DocumentChunk` classes for entity models.
This commit is contained in:
@@ -0,0 +1,51 @@
|
|||||||
|
using EntityFramework.Exceptions.SqlServer;
|
||||||
|
using Microsoft.EntityFrameworkCore;
|
||||||
|
using SqlDatabaseVectorSearch.DataAccessLayer.Entities;
|
||||||
|
|
||||||
|
namespace SqlDatabaseVectorSearch.DataAccessLayer;
|
||||||
|
|
||||||
|
/// <summary>
/// Entity Framework Core context for the application. It exposes the
/// <c>Documents</c> and <c>DocumentChunks</c> tables and declares their mapping.
/// </summary>
/// <param name="options">Context options handed straight to the <see cref="DbContext"/> base class.</param>
public class ApplicationDbContext(DbContextOptions<ApplicationDbContext> options) : DbContext(options)
{
    /// <summary>Set of <see cref="Document"/> entities (table <c>Documents</c>).</summary>
    public virtual DbSet<Document> Documents { get; set; }

    /// <summary>Set of <see cref="DocumentChunk"/> entities (table <c>DocumentChunks</c>).</summary>
    public virtual DbSet<DocumentChunk> DocumentChunks { get; set; }

    /// <summary>
    /// Runs the base configuration and then registers the exception processor
    /// provided by <c>EntityFramework.Exceptions.SqlServer</c>.
    /// </summary>
    /// <param name="optionsBuilder">Builder for the options of this context instance.</param>
    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
    {
        base.OnConfiguring(optionsBuilder);

        optionsBuilder.UseExceptionProcessor();

        // Sensitive data logging is deliberately left switched off.
        //optionsBuilder.EnableSensitiveDataLogging();
    }

    /// <summary>
    /// Declares table names, keys, column constraints and the
    /// document-to-chunks relationship.
    /// </summary>
    /// <param name="modelBuilder">Builder used to describe the entity model.</param>
    protected override void OnModelCreating(ModelBuilder modelBuilder)
    {
        // Documents: generated key plus a mandatory name capped at 255 characters.
        var documents = modelBuilder.Entity<Document>();
        documents.ToTable("Documents");
        documents.HasKey(document => document.Id);
        documents.Property(document => document.Id).ValueGeneratedOnAdd();
        documents.Property(document => document.Name).IsRequired().HasMaxLength(255);

        // DocumentChunks: generated key, mandatory text and a mandatory embedding column.
        var chunks = modelBuilder.Entity<DocumentChunk>();
        chunks.ToTable("DocumentChunks");
        chunks.HasKey(chunk => chunk.Id);
        chunks.Property(chunk => chunk.Id).ValueGeneratedOnAdd();
        chunks.Property(chunk => chunk.Content).IsRequired();
        chunks.Property(chunk => chunk.Embedding).IsRequired().HasColumnType("vector(1536)");

        // Each chunk belongs to one document; deleting the document cascades to its chunks.
        chunks.HasOne(chunk => chunk.Document)
            .WithMany(document => document.Chunks)
            .HasForeignKey(chunk => chunk.DocumentId)
            .OnDelete(DeleteBehavior.Cascade)
            .HasConstraintName("FK_DocumentChunks_Documents");
    }
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities;
|
||||||
|
|
||||||
|
/// <summary>
/// Persistence entity for a stored document together with the chunks that belong to it.
/// </summary>
public class Document
{
    /// <summary>Primary key of the document.</summary>
    public Guid Id { get; set; }

    /// <summary>Name of the document; must be supplied when the entity is created.</summary>
    public required string Name { get; set; }

    /// <summary>Timestamp recorded for the creation of the document.</summary>
    public DateTimeOffset CreationDate { get; set; }

    /// <summary>Chunks that reference this document; starts out as an empty collection.</summary>
    public virtual ICollection<DocumentChunk> Chunks { get; set; } = new List<DocumentChunk>();
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities;
|
||||||
|
|
||||||
|
/// <summary>
/// Persistence entity for one chunk of a document's text and the embedding stored with it.
/// </summary>
public class DocumentChunk
{
    /// <summary>Primary key of the chunk.</summary>
    public Guid Id { get; set; }

    /// <summary>Foreign key of the <see cref="Document"/> this chunk belongs to.</summary>
    public Guid DocumentId { get; set; }

    /// <summary>Position of the chunk inside its document.</summary>
    public int Index { get; set; }

    /// <summary>Text of the chunk; must be supplied when the entity is created.</summary>
    public required string Content { get; set; }

    /// <summary>Embedding values for <see cref="Content"/>; must be supplied when the entity is created.</summary>
    public required float[] Embedding { get; set; }

    /// <summary>
    /// Navigation to the owning document. It is declared non-nullable and
    /// initialised with a null-forgiving default.
    /// </summary>
    public virtual Document Document { get; set; } = default!;
}
|
||||||
@@ -1,14 +1,3 @@
|
|||||||
using System.Text.Json;
|
namespace SqlDatabaseVectorSearch.Models;
|
||||||
|
|
||||||
namespace SqlDatabaseVectorSearch.Models;
|
public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding = null);
|
||||||
|
|
||||||
public record class DocumentChunk(Guid Id, int Index, string Content, float[]? Embedding)
|
|
||||||
{
|
|
||||||
public DocumentChunk(Guid Id, int Index, string Content) : this(Id, Index, Content, (float[]?)null)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
public DocumentChunk(Guid Id, int Index, string Content, string Embedding) : this(Id, Index, Content, JsonSerializer.Deserialize<float[]?>(Embedding))
|
|
||||||
{
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
using Microsoft.AspNetCore.Http.HttpResults;
|
using Microsoft.AspNetCore.Http.HttpResults;
|
||||||
using Microsoft.Data.SqlClient;
|
using Microsoft.EntityFrameworkCore;
|
||||||
using Microsoft.OpenApi.Models;
|
using Microsoft.OpenApi.Models;
|
||||||
using Microsoft.SemanticKernel;
|
using Microsoft.SemanticKernel;
|
||||||
using MinimalHelpers.OpenApi;
|
using MinimalHelpers.OpenApi;
|
||||||
|
using SqlDatabaseVectorSearch.DataAccessLayer;
|
||||||
using SqlDatabaseVectorSearch.Models;
|
using SqlDatabaseVectorSearch.Models;
|
||||||
using SqlDatabaseVectorSearch.Services;
|
using SqlDatabaseVectorSearch.Services;
|
||||||
using SqlDatabaseVectorSearch.Settings;
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
@@ -18,10 +19,12 @@ var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configur
|
|||||||
|
|
||||||
builder.Services.AddSingleton(TimeProvider.System);
|
builder.Services.AddSingleton(TimeProvider.System);
|
||||||
|
|
||||||
builder.Services.AddScoped(_ =>
|
builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options =>
|
||||||
{
|
{
|
||||||
var sqlConnection = new SqlConnection(builder.Configuration.GetConnectionString("SqlConnection"));
|
options.UseVectorSearch();
|
||||||
return sqlConnection;
|
}, options =>
|
||||||
|
{
|
||||||
|
options.UseQueryTrackingBehavior(QueryTrackingBehavior.NoTracking);
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.Services.AddMemoryCache();
|
builder.Services.AddMemoryCache();
|
||||||
|
|||||||
@@ -1,21 +1,20 @@
|
|||||||
using System.Data;
|
using System.Data;
|
||||||
using System.Data.Common;
|
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Text.Json;
|
using Microsoft.EntityFrameworkCore;
|
||||||
using Dapper;
|
|
||||||
using Microsoft.Data.SqlClient;
|
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Microsoft.SemanticKernel.Embeddings;
|
using Microsoft.SemanticKernel.Embeddings;
|
||||||
using Microsoft.SemanticKernel.Text;
|
using Microsoft.SemanticKernel.Text;
|
||||||
|
using SqlDatabaseVectorSearch.DataAccessLayer;
|
||||||
using SqlDatabaseVectorSearch.Models;
|
using SqlDatabaseVectorSearch.Models;
|
||||||
using SqlDatabaseVectorSearch.Settings;
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
using TinyHelpers.Extensions;
|
using TinyHelpers.Extensions;
|
||||||
using UglyToad.PdfPig;
|
using UglyToad.PdfPig;
|
||||||
using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor;
|
using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor;
|
||||||
|
using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities;
|
||||||
|
|
||||||
namespace SqlDatabaseVectorSearch.Services;
|
namespace SqlDatabaseVectorSearch.Services;
|
||||||
|
|
||||||
public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions)
|
public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions)
|
||||||
{
|
{
|
||||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||||
|
|
||||||
@@ -24,21 +23,16 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
|
|||||||
// Extract the contents of the file (currently, only PDF files are supported).
|
// Extract the contents of the file (currently, only PDF files are supported).
|
||||||
var content = await GetContentAsync(stream);
|
var content = await GetContentAsync(stream);
|
||||||
|
|
||||||
await sqlConnection.OpenAsync();
|
await dbContext.Database.BeginTransactionAsync();
|
||||||
await using var transaction = await sqlConnection.BeginTransactionAsync();
|
|
||||||
|
|
||||||
if (documentId.HasValue)
|
if (documentId.HasValue)
|
||||||
{
|
{
|
||||||
// If the user is importing a document that already exists, delete the previous one.
|
// If the user is importing a document that already exists, delete the previous one.
|
||||||
await DeleteDocumentAsync(documentId.Value, transaction);
|
await DeleteDocumentAsync(documentId.Value);
|
||||||
}
|
}
|
||||||
|
|
||||||
documentId = await sqlConnection.ExecuteScalarAsync<Guid>($"""
|
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
|
||||||
INSERT INTO Documents (Id, [Name], CreationDate)
|
dbContext.Documents.Add(document);
|
||||||
OUTPUT INSERTED.Id
|
|
||||||
VALUES (@Id, @Name, @CreationDate);
|
|
||||||
""", new { Id = documentId.GetValueOrDefault(Guid.NewGuid()), Name = name, CreationDate = timeProvider.GetUtcNow() },
|
|
||||||
transaction);
|
|
||||||
|
|
||||||
// Split the content into chunks and generate the embeddings for each one.
|
// Split the content into chunks and generate the embeddings for each one.
|
||||||
var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens);
|
var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens);
|
||||||
@@ -47,54 +41,45 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
|
|||||||
// Save the document chunks and the corresponding embedding in the database.
|
// Save the document chunks and the corresponding embedding in the database.
|
||||||
foreach (var (paragraph, index) in paragraphs.WithIndex())
|
foreach (var (paragraph, index) in paragraphs.WithIndex())
|
||||||
{
|
{
|
||||||
await sqlConnection.ExecuteAsync($"""
|
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() };
|
||||||
INSERT INTO DocumentChunks (DocumentId, [Index], Content, Embedding)
|
dbContext.DocumentChunks.Add(documentChunk);
|
||||||
VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR({embeddings[index].Length})));
|
|
||||||
""", new { DocumentId = documentId, Index = index, Content = paragraph, Embedding = JsonSerializer.Serialize(embeddings[index]) },
|
|
||||||
transaction);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await transaction.CommitAsync();
|
await dbContext.SaveChangesAsync();
|
||||||
|
await dbContext.Database.CommitTransactionAsync();
|
||||||
|
|
||||||
return documentId.Value;
|
return document.Id;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<IEnumerable<Document>> GetDocumentsAsync()
|
public async Task<IEnumerable<Document>> GetDocumentsAsync()
|
||||||
{
|
{
|
||||||
var documents = await sqlConnection.QueryAsync<Document>("""
|
var documents = await dbContext.Documents.OrderBy(d => d.Name)
|
||||||
SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks WHERE DocumentId = Documents.Id)
|
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count))
|
||||||
FROM Documents
|
.ToListAsync();
|
||||||
ORDER BY [Name];
|
|
||||||
""");
|
|
||||||
|
|
||||||
return documents;
|
return documents;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
|
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
|
||||||
{
|
{
|
||||||
var documentChunks = await sqlConnection.QueryAsync<DocumentChunk>("""
|
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index)
|
||||||
SELECT Id, [Index], Content
|
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null))
|
||||||
FROM DocumentChunks
|
.ToListAsync();
|
||||||
WHERE DocumentId = @DocumentId
|
|
||||||
ORDER BY [Index];
|
|
||||||
""", new { documentId });
|
|
||||||
|
|
||||||
return documentChunks;
|
return documentChunks;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
|
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
|
||||||
{
|
{
|
||||||
var documentChunk = await sqlConnection.QueryFirstOrDefaultAsync<DocumentChunk>("""
|
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId)
|
||||||
SELECT Id, [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) AS Embedding
|
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding))
|
||||||
FROM DocumentChunks
|
.FirstOrDefaultAsync();
|
||||||
WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId;
|
|
||||||
""", new { documentId, documentChunkId });
|
|
||||||
|
|
||||||
return documentChunk;
|
return documentChunk;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null)
|
public Task DeleteDocumentAsync(Guid documentId)
|
||||||
=> sqlConnection.ExecuteAsync("DELETE FROM Documents WHERE Id = @DocumentId", new { DocumentId = documentId }, transaction);
|
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
|
||||||
|
|
||||||
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
|
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
|
||||||
{
|
{
|
||||||
@@ -104,11 +89,11 @@ public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGene
|
|||||||
// Perform Vector Search on SQL Database.
|
// Perform Vector Search on SQL Database.
|
||||||
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
|
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
|
||||||
|
|
||||||
var chunks = await sqlConnection.QueryAsync<string>($"""
|
var chunks = await dbContext.DocumentChunks
|
||||||
SELECT TOP (@MaxRelevantChunks) Content
|
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
|
||||||
FROM DocumentChunks
|
.Select(c => c.Content)
|
||||||
ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR({questionEmbedding.Length})));
|
.Take(appSettings.MaxRelevantChunks)
|
||||||
""", new { appSettings.MaxRelevantChunks, QuestionEmbedding = JsonSerializer.Serialize(questionEmbedding) });
|
.ToListAsync();
|
||||||
|
|
||||||
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
|
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
|
||||||
return new Response(reformulatedQuestion, answer);
|
return new Response(reformulatedQuestion, answer);
|
||||||
|
|||||||
@@ -8,9 +8,10 @@
|
|||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<PackageReference Include="Dapper" Version="2.1.35" />
|
<PackageReference Include="EFCore.SqlServer.VectorSearch" Version="0.2.0" />
|
||||||
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.10" />
|
<PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="8.1.3" />
|
||||||
<PackageReference Include="Microsoft.Data.SqlClient" Version="5.2.2" />
|
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.10" />
|
||||||
|
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="8.0.10" />
|
||||||
<PackageReference Include="Microsoft.SemanticKernel" Version="1.25.0" />
|
<PackageReference Include="Microsoft.SemanticKernel" Version="1.25.0" />
|
||||||
<PackageReference Include="MinimalHelpers.OpenApi" Version="2.0.17" />
|
<PackageReference Include="MinimalHelpers.OpenApi" Version="2.0.17" />
|
||||||
<PackageReference Include="PdfPig" Version="0.1.9" />
|
<PackageReference Include="PdfPig" Version="0.1.9" />
|
||||||
|
|||||||
Reference in New Issue
Block a user