mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Refactor to use the native VECTOR type
This commit is contained in:
@@ -1,18 +1,19 @@
|
||||
using System.Text;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using Microsoft.Data.SqlClient;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.SemanticKernel.Embeddings;
|
||||
using Microsoft.SemanticKernel.Text;
|
||||
using SqlDatabaseVectorSearch.DataAccessLayer;
|
||||
using SqlDatabaseVectorSearch.Models;
|
||||
using SqlDatabaseVectorSearch.Settings;
|
||||
using UglyToad.PdfPig;
|
||||
using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor;
|
||||
using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Services;
|
||||
|
||||
public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions)
|
||||
public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions)
|
||||
{
|
||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||
|
||||
@@ -21,16 +22,25 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG
|
||||
// Extract the contents of the file (currently, only PDF files are supported).
|
||||
var content = await GetContentAsync(stream);
|
||||
|
||||
await dbContext.Database.BeginTransactionAsync();
|
||||
await sqlConnection.OpenAsync();
|
||||
await using var transaction = await sqlConnection.BeginTransactionAsync();
|
||||
|
||||
if (documentId.HasValue)
|
||||
{
|
||||
// If the user is importing a document that already exists, delete the previous one.
|
||||
await DeleteDocumentAsync(documentId.Value);
|
||||
await DeleteDocumentAsync(documentId.Value, transaction);
|
||||
}
|
||||
|
||||
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
|
||||
dbContext.Documents.Add(document);
|
||||
await using var command = sqlConnection.CreateCommand();
|
||||
command.Transaction = (SqlTransaction)transaction;
|
||||
|
||||
command.CommandText = "INSERT INTO Documents2 (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate)";
|
||||
command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid()));
|
||||
command.Parameters.AddWithValue("@Name", name);
|
||||
command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow());
|
||||
|
||||
var insertedId = await command.ExecuteScalarAsync();
|
||||
documentId = (Guid)insertedId!;
|
||||
|
||||
// Split the content into chunks and generate the embeddings for each one.
|
||||
var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens);
|
||||
@@ -39,45 +49,105 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG
|
||||
var index = 0;
|
||||
foreach (var (paragraph, embedding) in paragraphs.Zip(embeddings, (p, e) => (p, e.ToArray())))
|
||||
{
|
||||
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index++, Content = paragraph, Embedding = embedding };
|
||||
dbContext.DocumentChunks.Add(documentChunk);
|
||||
command.Parameters.Clear();
|
||||
|
||||
command.CommandText = "INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR(1536)))";
|
||||
command.Parameters.AddWithValue("@DocumentId", documentId);
|
||||
command.Parameters.AddWithValue("@Index", index++);
|
||||
command.Parameters.AddWithValue("@Content", paragraph);
|
||||
command.Parameters.AddWithValue("@Embedding", JsonSerializer.Serialize(embedding));
|
||||
|
||||
await command.ExecuteNonQueryAsync();
|
||||
}
|
||||
|
||||
await dbContext.SaveChangesAsync();
|
||||
await dbContext.Database.CommitTransactionAsync();
|
||||
await transaction.CommitAsync();
|
||||
|
||||
return document.Id;
|
||||
return documentId.Value;
|
||||
}
|
||||
|
||||
/// <summary>
/// Lists every row of <c>Documents2</c>, ordered by name, together with the number of
/// rows in <c>DocumentChunks2</c> that reference it.
/// </summary>
/// <returns>The documents that were read; an empty sequence when the query yields no rows.</returns>
public async Task<IEnumerable<Document>> GetDocumentsAsync()
{
    // The connection is injected through the constructor and may already have been opened by a
    // previous call on this instance; OpenAsync throws InvalidOperationException on an open connection.
    if (sqlConnection.State == ConnectionState.Closed)
    {
        await sqlConnection.OpenAsync();
    }

    await using var command = sqlConnection.CreateCommand();
    command.CommandText = "SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 WHERE DocumentId = Documents2.Id) FROM Documents2 ORDER BY [Name]";

    var documents = new List<Document>();

    // Dispose the reader asynchronously so the connection is free for the next command.
    await using var reader = await command.ExecuteReaderAsync();
    while (await reader.ReadAsync())
    {
        // Ordinals follow the column order of the SELECT list above.
        var id = reader.GetGuid(0);
        var name = reader.GetString(1);
        var creationDate = reader.GetDateTimeOffset(2);
        var chunkCount = reader.GetInt32(3);

        documents.Add(new(id, name, creationDate, chunkCount));
    }

    return documents;
}
|
||||
|
||||
/// <summary>
/// Reads the chunks of a document from <c>DocumentChunks2</c>, ordered by their index.
/// The embedding is not loaded; each returned chunk carries <c>null</c> in its place.
/// </summary>
/// <param name="documentId">Identifier matched against the <c>DocumentId</c> column.</param>
/// <returns>The matching chunks; an empty sequence when no row matches.</returns>
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
{
    // The connection is injected through the constructor and may already have been opened by a
    // previous call on this instance; OpenAsync throws InvalidOperationException on an open connection.
    if (sqlConnection.State == ConnectionState.Closed)
    {
        await sqlConnection.OpenAsync();
    }

    await using var command = sqlConnection.CreateCommand();
    command.CommandText = "SELECT Id, [Index], Content FROM DocumentChunks2 WHERE DocumentId = @DocumentId ORDER BY [Index]";
    command.Parameters.AddWithValue("@DocumentId", documentId);

    var documentChunks = new List<DocumentChunk>();

    // Dispose the reader asynchronously so the connection is free for the next command.
    await using var reader = await command.ExecuteReaderAsync();
    while (await reader.ReadAsync())
    {
        // Ordinals follow the column order of the SELECT list above.
        var id = reader.GetGuid(0);
        var index = reader.GetInt32(1);
        var content = reader.GetString(2);

        documentChunks.Add(new(id, index, content, null));
    }

    return documentChunks;
}
|
||||
|
||||
/// <summary>
/// Loads a single chunk, including its embedding, from <c>DocumentChunks2</c>.
/// </summary>
/// <param name="documentId">Identifier matched against the <c>DocumentId</c> column.</param>
/// <param name="documentChunkId">Identifier matched against the <c>Id</c> column.</param>
/// <returns>The chunk with its embedding, or <c>null</c> when no row matches both identifiers.</returns>
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
{
    // The connection is injected through the constructor and may already have been opened by a
    // previous call on this instance; OpenAsync throws InvalidOperationException on an open connection.
    if (sqlConnection.State == ConnectionState.Closed)
    {
        await sqlConnection.OpenAsync();
    }

    await using var command = sqlConnection.CreateCommand();

    // The vector column is cast to NVARCHAR(MAX) so it reaches the client as text that can be parsed as JSON.
    command.CommandText = "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId";
    command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId);
    command.Parameters.AddWithValue("@DocumentId", documentId);

    // Dispose the reader asynchronously so the connection is free for the next command.
    await using var reader = await command.ExecuteReaderAsync();

    // ReadAsync already returns false on an empty result set, so no separate HasRows check is needed.
    if (!await reader.ReadAsync())
    {
        return null;
    }

    var index = reader.GetInt32(0);
    var content = reader.GetString(1);
    var embedding = JsonSerializer.Deserialize<float[]>(reader.GetString(2))!;

    return new(documentChunkId, index, content, embedding);
}
|
||||
|
||||
/// <summary>
/// Deletes the row of <c>Documents2</c> with the given identifier.
/// </summary>
/// <param name="documentId">Identifier matched against the <c>Id</c> column.</param>
/// <param name="transaction">
/// Optional transaction to enlist the command in. Only a <see cref="SqlTransaction"/> is honoured;
/// any other value is treated like <c>null</c>.
/// </param>
public async Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null)
{
    // Callers that pass a transaction have already opened the connection, so open it only when needed.
    if (sqlConnection.State == ConnectionState.Closed)
    {
        await sqlConnection.OpenAsync();
    }

    await using var command = sqlConnection.CreateCommand();
    command.Transaction = transaction as SqlTransaction;

    command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId";
    command.Parameters.AddWithValue("@DocumentId", documentId);

    await command.ExecuteNonQueryAsync();
}
|
||||
|
||||
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
|
||||
{
|
||||
@@ -87,11 +157,21 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG
|
||||
// Perform Vector Search on SQL Database.
|
||||
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
|
||||
|
||||
var chunks = await dbContext.DocumentChunks
|
||||
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
|
||||
.Select(c => c.Content)
|
||||
.Take(appSettings.MaxRelevantChunks)
|
||||
.ToListAsync();
|
||||
await sqlConnection.OpenAsync();
|
||||
await using var command = sqlConnection.CreateCommand();
|
||||
|
||||
command.CommandText = "SELECT TOP (@MaxRelevantChunks) Content FROM DocumentChunks2 ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR(1536)))";
|
||||
command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks);
|
||||
command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding));
|
||||
|
||||
var chunks = new List<string>();
|
||||
|
||||
using var reader = await command.ExecuteReaderAsync();
|
||||
while (await reader.ReadAsync())
|
||||
{
|
||||
var content = reader.GetString(0);
|
||||
chunks.Add(content);
|
||||
}
|
||||
|
||||
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
|
||||
return new Response(reformulatedQuestion, answer);
|
||||
|
||||
Reference in New Issue
Block a user