Refactor to use the native VECTOR type

This commit is contained in:
Marco Minerva
2024-09-30 17:08:28 +02:00
parent 1ef2076ba6
commit 3e95251485
7 changed files with 129 additions and 126 deletions
+13 -7
View File
@@ -1,26 +1,32 @@
CREATE TABLE [dbo].[DocumentChunks]( CREATE TABLE [dbo].[DocumentChunks2](
[Id] [uniqueidentifier] NOT NULL, [Id] [uniqueidentifier] NOT NULL,
[DocumentId] [uniqueidentifier] NOT NULL, [DocumentId] [uniqueidentifier] NOT NULL,
[Index] [int] NOT NULL, [Index] [int] NOT NULL,
[Content] [nvarchar](max) NOT NULL, [Content] [nvarchar](max) NOT NULL,
[Embedding] [varbinary](8000) NOT NULL, [Embedding] [vector](1536) NOT NULL,
CONSTRAINT [PK_DocumentChunks] PRIMARY KEY CLUSTERED CONSTRAINT [PK_DocumentChunks2] PRIMARY KEY CLUSTERED
( (
[Id] ASC [Id] ASC
)) ))
GO GO
CREATE TABLE [dbo].[Documents]( CREATE TABLE [dbo].[Documents2](
[Id] [uniqueidentifier] NOT NULL, [Id] [uniqueidentifier] NOT NULL,
[Name] [nvarchar](255) NOT NULL, [Name] [nvarchar](255) NOT NULL,
[CreationDate] [datetimeoffset](7) NOT NULL, [CreationDate] [datetimeoffset](7) NOT NULL,
CONSTRAINT [PK_Documents] PRIMARY KEY CLUSTERED CONSTRAINT [PK_Documents2] PRIMARY KEY CLUSTERED
( (
[Id] ASC [Id] ASC
)) ))
GO GO
ALTER TABLE [dbo].[DocumentChunks] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks_Documents] FOREIGN KEY([DocumentId]) ALTER TABLE [dbo].[DocumentChunks2] WITH CHECK ADD CONSTRAINT [FK_DocumentChunks2_Documents2] FOREIGN KEY([DocumentId])
REFERENCES [dbo].[Documents] ([Id]) REFERENCES [dbo].[Documents2] ([Id])
ON DELETE CASCADE ON DELETE CASCADE
GO GO
ALTER TABLE [dbo].[Documents2] ADD CONSTRAINT [DF_Documents2_Id] DEFAULT (newsequentialid()) FOR [Id]
GO
ALTER TABLE [dbo].[DocumentChunks2] ADD CONSTRAINT [DF_DocumentChunks2_Id] DEFAULT (newsequentialid()) FOR [Id]
GO
@@ -1,52 +0,0 @@
using EntityFramework.Exceptions.SqlServer;
using Microsoft.EntityFrameworkCore;
using SqlDatabaseVectorSearch.DataAccessLayer.Entities;
namespace SqlDatabaseVectorSearch.DataAccessLayer;
public class ApplicationDbContext(DbContextOptions<ApplicationDbContext> options) : DbContext(options)
{
public virtual DbSet<Document> Documents { get; set; }
public virtual DbSet<DocumentChunk> DocumentChunks { get; set; }
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
{
base.OnConfiguring(optionsBuilder);
optionsBuilder.UseExceptionProcessor();
//optionsBuilder.EnableSensitiveDataLogging();
}
protected override void OnModelCreating(ModelBuilder modelBuilder)
{
modelBuilder.Entity<Document>(entity =>
{
entity.ToTable("Documents");
entity.HasKey(e => e.Id);
entity.Property(e => e.Id).ValueGeneratedOnAdd();
entity.Property(e => e.Name)
.IsRequired()
.HasMaxLength(255);
});
modelBuilder.Entity<DocumentChunk>(entity =>
{
entity.ToTable("DocumentChunks");
entity.HasKey(e => e.Id);
entity.Property(e => e.Id).ValueGeneratedOnAdd();
entity.Property(e => e.Content).IsRequired();
entity.Property(e => e.Embedding)
.IsRequired()
.HasMaxLength(8000)
.IsVector();
entity.HasOne(d => d.Document).WithMany(p => p.Chunks)
.HasForeignKey(d => d.DocumentId)
.OnDelete(DeleteBehavior.Cascade)
.HasConstraintName("FK_DocumentChunks_Documents");
});
}
}
@@ -1,12 +0,0 @@
namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities;
public class Document
{
public Guid Id { get; set; }
public required string Name { get; set; }
public DateTimeOffset CreationDate { get; set; }
public virtual ICollection<DocumentChunk> Chunks { get; set; } = [];
}
@@ -1,16 +0,0 @@
namespace SqlDatabaseVectorSearch.DataAccessLayer.Entities;
public class DocumentChunk
{
public Guid Id { get; set; }
public Guid DocumentId { get; set; }
public int Index { get; set; }
public required string Content { get; set; }
public required float[] Embedding { get; set; }
public virtual Document Document { get; set; } = null!;
}
+4 -6
View File
@@ -1,9 +1,9 @@
using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.OpenApi.Models; using Microsoft.OpenApi.Models;
using Microsoft.SemanticKernel; using Microsoft.SemanticKernel;
using MinimalHelpers.OpenApi; using MinimalHelpers.OpenApi;
using SqlDatabaseVectorSearch.DataAccessLayer;
using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Services;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
@@ -19,12 +19,10 @@ var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configur
builder.Services.AddSingleton(TimeProvider.System); builder.Services.AddSingleton(TimeProvider.System);
builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options => builder.Services.AddScoped(_ =>
{ {
options.UseVectorSearch(); var sqlConnection = new SqlConnection(builder.Configuration.GetConnectionString("SqlConnection"));
}, options => return sqlConnection;
{
options.UseQueryTrackingBehavior(QueryTrackingBehavior.NoTracking);
}); });
builder.Services.AddMemoryCache(); builder.Services.AddMemoryCache();
@@ -1,18 +1,19 @@
using System.Text; using System.Data;
using Microsoft.EntityFrameworkCore; using System.Data.Common;
using System.Text;
using System.Text.Json;
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Text; using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.DataAccessLayer;
using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
using UglyToad.PdfPig; using UglyToad.PdfPig;
using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor; using UglyToad.PdfPig.DocumentLayoutAnalysis.TextExtractor;
using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities;
namespace SqlDatabaseVectorSearch.Services; namespace SqlDatabaseVectorSearch.Services;
public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions) public class VectorSearchService(SqlConnection sqlConnection, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions)
{ {
private readonly AppSettings appSettings = appSettingsOptions.Value; private readonly AppSettings appSettings = appSettingsOptions.Value;
@@ -21,16 +22,25 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG
// Extract the contents of the file (currently, only PDF files are supported). // Extract the contents of the file (currently, only PDF files are supported).
var content = await GetContentAsync(stream); var content = await GetContentAsync(stream);
await dbContext.Database.BeginTransactionAsync(); await sqlConnection.OpenAsync();
await using var transaction = await sqlConnection.BeginTransactionAsync();
if (documentId.HasValue) if (documentId.HasValue)
{ {
// If the user is importing a document that already exists, delete the previous one. // If the user is importing a document that already exists, delete the previous one.
await DeleteDocumentAsync(documentId.Value); await DeleteDocumentAsync(documentId.Value, transaction);
} }
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; await using var command = sqlConnection.CreateCommand();
dbContext.Documents.Add(document); command.Transaction = (SqlTransaction)transaction;
command.CommandText = "INSERT INTO Documents2 (Id, [Name], CreationDate) OUTPUT INSERTED.Id VALUES (@Id, @Name, @CreationDate)";
command.Parameters.AddWithValue("@Id", documentId.GetValueOrDefault(Guid.NewGuid()));
command.Parameters.AddWithValue("@Name", name);
command.Parameters.AddWithValue("@CreationDate", timeProvider.GetUtcNow());
var insertedId = await command.ExecuteScalarAsync();
documentId = (Guid)insertedId!;
// Split the content into chunks and generate the embeddings for each one. // Split the content into chunks and generate the embeddings for each one.
var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens); var paragraphs = TextChunker.SplitPlainTextParagraphs(TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine), appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens);
@@ -39,45 +49,105 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG
var index = 0; var index = 0;
foreach (var (paragraph, embedding) in paragraphs.Zip(embeddings, (p, e) => (p, e.ToArray()))) foreach (var (paragraph, embedding) in paragraphs.Zip(embeddings, (p, e) => (p, e.ToArray())))
{ {
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index++, Content = paragraph, Embedding = embedding }; command.Parameters.Clear();
dbContext.DocumentChunks.Add(documentChunk);
command.CommandText = "INSERT INTO DocumentChunks2 (DocumentId, [Index], Content, Embedding) VALUES (@DocumentId, @Index, @Content, CAST(@Embedding AS VECTOR(1536)))";
command.Parameters.AddWithValue("@DocumentId", documentId);
command.Parameters.AddWithValue("@Index", index++);
command.Parameters.AddWithValue("@Content", paragraph);
command.Parameters.AddWithValue("@Embedding", JsonSerializer.Serialize(embedding));
await command.ExecuteNonQueryAsync();
} }
await dbContext.SaveChangesAsync(); await transaction.CommitAsync();
await dbContext.Database.CommitTransactionAsync();
return document.Id; return documentId.Value;
} }
public async Task<IEnumerable<Document>> GetDocumentsAsync() public async Task<IEnumerable<Document>> GetDocumentsAsync()
{ {
var documents = await dbContext.Documents.OrderBy(d => d.Name) await sqlConnection.OpenAsync();
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count)) await using var command = sqlConnection.CreateCommand();
.ToListAsync();
command.CommandText = "SELECT Id, [Name], CreationDate, ChunkCount = (SELECT COUNT(*) FROM DocumentChunks2 WHERE DocumentId = Documents2.Id) FROM Documents2 ORDER BY [Name]";
var documents = new List<Document>();
using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var id = reader.GetGuid(0);
var name = reader.GetString(1);
var creationDate = reader.GetDateTimeOffset(2);
var chunkCount = reader.GetInt32(3);
documents.Add(new(id, name, creationDate, chunkCount));
}
return documents; return documents;
} }
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId) public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
{ {
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index) await sqlConnection.OpenAsync();
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null)) await using var command = sqlConnection.CreateCommand();
.ToListAsync();
command.CommandText = "SELECT Id, [Index], Content FROM DocumentChunks2 WHERE DocumentId = @DocumentId ORDER BY [Index]";
command.Parameters.AddWithValue("@DocumentId", documentId);
var documentChunks = new List<DocumentChunk>();
using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var id = reader.GetGuid(0);
var index = reader.GetInt32(1);
var content = reader.GetString(2);
documentChunks.Add(new(id, index, content, null));
}
return documentChunks; return documentChunks;
} }
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
{ {
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId) await sqlConnection.OpenAsync();
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding)) await using var command = sqlConnection.CreateCommand();
.FirstOrDefaultAsync();
return documentChunk; command.CommandText = "SELECT [Index], Content, CAST(Embedding AS NVARCHAR(MAX)) FROM DocumentChunks2 WHERE Id = @DocumentChunkId AND DocumentId = @DocumentId";
command.Parameters.AddWithValue("@DocumentChunkId", documentChunkId);
command.Parameters.AddWithValue("@DocumentId", documentId);
using var reader = await command.ExecuteReaderAsync();
if (reader.HasRows && await reader.ReadAsync())
{
var index = reader.GetInt32(0);
var content = reader.GetString(1);
var embedding = JsonSerializer.Deserialize<float[]>(reader.GetString(2))!;
return new(documentChunkId, index, content, embedding);
}
return null;
} }
public Task DeleteDocumentAsync(Guid documentId) public async Task DeleteDocumentAsync(Guid documentId, DbTransaction? transaction = null)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); {
if (sqlConnection.State == ConnectionState.Closed)
{
await sqlConnection.OpenAsync();
}
using var command = sqlConnection.CreateCommand();
command.Transaction = transaction as SqlTransaction;
command.CommandText = "DELETE FROM Documents2 WHERE Id = @DocumentId";
command.Parameters.AddWithValue("@DocumentId", documentId);
await command.ExecuteNonQueryAsync();
}
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true) public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
{ {
@@ -87,11 +157,21 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG
// Perform Vector Search on SQL Database. // Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
var chunks = await dbContext.DocumentChunks await sqlConnection.OpenAsync();
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) await using var command = sqlConnection.CreateCommand();
.Select(c => c.Content)
.Take(appSettings.MaxRelevantChunks) command.CommandText = "SELECT TOP (@MaxRelevantChunks) Content FROM DocumentChunks2 ORDER BY VECTOR_DISTANCE('cosine', Embedding, CAST(@QuestionEmbedding AS VECTOR(1536)))";
.ToListAsync(); command.Parameters.AddWithValue("@MaxRelevantChunks", appSettings.MaxRelevantChunks);
command.Parameters.AddWithValue("@QuestionEmbedding", JsonSerializer.Serialize(questionEmbedding));
var chunks = new List<string>();
using var reader = await command.ExecuteReaderAsync();
while (await reader.ReadAsync())
{
var content = reader.GetString(0);
chunks.Add(content);
}
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
return new Response(reformulatedQuestion, answer); return new Response(reformulatedQuestion, answer);
@@ -8,10 +8,9 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="EFCore.SqlServer.VectorSearch" Version="0.1.1" />
<PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="8.1.3" /> <PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="8.1.3" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.8" /> <PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.8" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="8.0.8" /> <PackageReference Include="Microsoft.Data.SqlClient" Version="5.2.2" />
<PackageReference Include="Microsoft.SemanticKernel" Version="1.21.1" /> <PackageReference Include="Microsoft.SemanticKernel" Version="1.21.1" />
<PackageReference Include="MinimalHelpers.OpenApi" Version="2.0.15" /> <PackageReference Include="MinimalHelpers.OpenApi" Version="2.0.15" />
<PackageReference Include="PdfPig" Version="0.1.8" /> <PackageReference Include="PdfPig" Version="0.1.8" />