diff --git a/SqlDatabaseVectorSearch/Components/Pages/Documents.razor b/SqlDatabaseVectorSearch/Components/Pages/Documents.razor index eec9704..2eca5f6 100644 --- a/SqlDatabaseVectorSearch/Components/Pages/Documents.razor +++ b/SqlDatabaseVectorSearch/Components/Pages/Documents.razor @@ -177,7 +177,7 @@ else var vectorSearchService = scope.ServiceProvider.GetRequiredService(); var documentId = string.IsNullOrWhiteSpace(Model.DocumentId) ? null : (Guid?)Guid.Parse(Model.DocumentId); - await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId); + //await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId); ToastService.Notify(await CreateToastMessageAsync(ToastType.Success, "Upload document", $"The document {fileName} has been successfully uploaded and indexed.")); diff --git a/SqlDatabaseVectorSearch/Endpoints/DocumentEndpoints.cs b/SqlDatabaseVectorSearch/Endpoints/DocumentEndpoints.cs index 6862f09..f42caeb 100644 --- a/SqlDatabaseVectorSearch/Endpoints/DocumentEndpoints.cs +++ b/SqlDatabaseVectorSearch/Endpoints/DocumentEndpoints.cs @@ -1,9 +1,9 @@  using System.ComponentModel; using Microsoft.AspNetCore.Http.HttpResults; -using MimeMapping; using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Services; +using SqlDatabaseVectorSearch.Workflows; namespace SqlDatabaseVectorSearch.Endpoints; @@ -23,12 +23,8 @@ public class DocumentEndpoints : IEndpointRouteHandlerBuilder documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken, [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => { - using var stream = file.OpenReadStream(); - - // Note: file.ContentType is not 100% reliable (for example, for markdown file). - var response = await vectorSearchService.ImportAsync(stream, file.FileName, MimeUtility.GetMimeMapping(file.FileName), documentId, cancellationToken); - - return TypedResults.Ok(response); + var result = await vectorSearchService.ImportAsync(new FormFileEmbeddingRequest(file, documentId), cancellationToken); + return TypedResults.Ok(result); }) .DisableAntiforgery() .ProducesProblem(StatusCodes.Status400BadRequest) diff --git a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs deleted file mode 100644 index 252018a..0000000 --- a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs +++ /dev/null @@ -1,3 +0,0 @@ -namespace SqlDatabaseVectorSearch.Models; - -public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount); diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index e99238c..61f0553 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -2,8 +2,11 @@ using System.ClientModel; using System.Net.Mime; using System.Text.Json.Serialization; using FluentValidation; +using Microsoft.Agents.AI.Hosting; +using Microsoft.Agents.AI.Workflows; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; using OpenAI; using OpenAI.Responses; using SqlDatabaseVectorSearch.Components; @@ -13,6 +16,7 @@ using SqlDatabaseVectorSearch.Extensions; using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.TextChunkers; +using SqlDatabaseVectorSearch.Workflows; using TinyHelpers.AspNetCore.Extensions; using TinyHelpers.AspNetCore.OpenApi; @@ -77,6 +81,11 @@ builder.Services.AddChatClient(_ => return chatClient; }); +// Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions, +// so that embeddings themselves can be generated more accurately. +builder.Services.AddKernel() + .AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey, modelId: aiSettings.ChatCompletion.ModelId); + builder.Services.AddKeyedSingleton(MediaTypeNames.Application.Pdf); builder.Services.AddKeyedSingleton("application/vnd.openxmlformats-officedocument.wordprocessingml.document"); builder.Services.AddKeyedSingleton(MediaTypeNames.Text.Plain); @@ -91,6 +100,25 @@ builder.Services.AddSingleton(); builder.Services.AddScoped(); builder.Services.AddScoped(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddScoped(); // This executor is registered as scoped because it uses the DbContext, which is also scoped. + +builder.AddWorkflow("EmbeddingWorkflow", (services, key) => +{ + var formfileToConversionRequestExecutor = services.GetRequiredService(); + var generateEmbeddingExecutor = services.GetRequiredService(); + var storeEmbeddingExecutor = services.GetRequiredService(); + + var workflow = new WorkflowBuilder(formfileToConversionRequestExecutor).WithName(key) + .AddEdge(formfileToConversionRequestExecutor, generateEmbeddingExecutor) + .AddEdge(generateEmbeddingExecutor, storeEmbeddingExecutor) + .WithOutputFrom(storeEmbeddingExecutor) + .Build(validateOrphans: true); + + return workflow; +}, ServiceLifetime.Scoped); + builder.Services.AddOpenApi(options => { options.RemoveServerList(); diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index f18d39e..8f5dbd7 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -2,84 +2,37 @@ using System.Runtime.CompilerServices; using System.Text; using System.Text.RegularExpressions; +using Microsoft.Agents.AI.Workflows; using Microsoft.Data.SqlTypes; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.AI; using Microsoft.Extensions.Options; -using SqlDatabaseVectorSearch.ContentDecoders; using SqlDatabaseVectorSearch.Data; using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Settings; +using SqlDatabaseVectorSearch.Workflows; using ChatResponse = SqlDatabaseVectorSearch.Models.ChatResponse; using Entities = SqlDatabaseVectorSearch.Data.Entities; namespace SqlDatabaseVectorSearch.Services; -public partial class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, DocumentService documentService, IEmbeddingGenerator> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) +public partial class VectorSearchService([FromKeyedServices("EmbeddingWorkflow")] Workflow workflow, ApplicationDbContext dbContext, IEmbeddingGenerator> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) { private readonly AppSettings appSettings = appSettingsOptions.Value; - public async Task ImportAsync(Stream stream, string name, string contentType, Guid? documentId, CancellationToken cancellationToken = default) + public async Task ImportAsync(FormFileEmbeddingRequest request, CancellationToken cancellationToken = default) { - // Extract the contents of the file. - var decoder = serviceProvider.GetKeyedService(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported."); - var chunks = await decoder.DecodeAsync(stream, contentType, cancellationToken); - var chunkContents = chunks.Select(p => p.Content).ToList(); + await using var run = await InProcessExecution.RunAsync(workflow, request, cancellationToken: cancellationToken); + var events = run.NewEvents.ToList(); - // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). - var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", chunkContents)); - - var strategy = dbContext.Database.CreateExecutionStrategy(); - var document = await strategy.ExecuteAsync(async (cancellationToken) => + var exception = events.OfType().Select(e => e.Exception).FirstOrDefault(); + if (exception is not null) { - await dbContext.Database.BeginTransactionAsync(cancellationToken); + throw exception; + } - if (documentId.HasValue) - { - // If the user is importing a document that already exists, delete the previous one. - await documentService.DeleteAsync(documentId.Value, cancellationToken); - } - - var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; - dbContext.Documents.Add(document); - - // Process paragraphs in batches. - var embeddings = new List>(); - foreach (var batch in chunkContents.Chunk(appSettings.EmbeddingBatchSize)) - { - logger.LogDebug("Processing batch of {Count} chunks for embedding generation...", batch.Length); - - // Generate embeddings for this batch. - var batchEmbeddings = await embeddingGenerator.GenerateAsync(batch, cancellationToken: cancellationToken); - embeddings.AddRange(batchEmbeddings); - } - - // Save the document chunks and the corresponding embedding in the database. - foreach (var (index, embedding) in embeddings.Index()) - { - var chunk = chunks.ElementAt(index); - logger.LogDebug("Storing a chunk of {TokenCount} tokens.", tokenizerService.CountEmbeddingTokens(chunk.Content)); - - var documentChunk = new Entities.DocumentChunk - { - Document = document, - Index = index, - PageNumber = chunk.PageNumber, - IndexOnPage = chunk.IndexOnPage, - Content = chunk.Content, - Embedding = new SqlVector(embedding.Vector) - }; - - dbContext.DocumentChunks.Add(documentChunk); - } - - await dbContext.SaveChangesAsync(cancellationToken); - await dbContext.Database.CommitTransactionAsync(cancellationToken); - - return document; - }, cancellationToken); - - return new(document.Id, tokenCount); + var result = events.OfType().Select(e => e.Data).OfType().First(); + return result; } public async Task AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default) diff --git a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj index 42361e4..47ca899 100644 --- a/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj +++ b/SqlDatabaseVectorSearch/SqlDatabaseVectorSearch.csproj @@ -14,6 +14,11 @@ + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + all diff --git a/SqlDatabaseVectorSearch/Workflows/FormFileToEmbeddingRequestExecutor.cs b/SqlDatabaseVectorSearch/Workflows/FormFileToEmbeddingRequestExecutor.cs new file mode 100644 index 0000000..4e93aa2 --- /dev/null +++ b/SqlDatabaseVectorSearch/Workflows/FormFileToEmbeddingRequestExecutor.cs @@ -0,0 +1,18 @@ +using Microsoft.Agents.AI.Workflows; + +namespace SqlDatabaseVectorSearch.Workflows; + +public partial class FormFileToEmbeddingRequestExecutor() : Executor(nameof(FormFileToEmbeddingRequestExecutor)) +{ + [MessageHandler] + private ValueTask HandleAsync(FormFileEmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken) + { + // Note: file.ContentType is not 100% reliable (for example, for markdown file). + var embeddingRequest = new EmbeddingRequest(request.File.OpenReadStream(), Path.GetFileName(request.File.FileName), MimeMapping.MimeUtility.GetMimeMapping(request.File.FileName), request.DocumentId); + return ValueTask.FromResult(embeddingRequest); + } +} + +public record class FormFileEmbeddingRequest(IFormFile File, Guid? DocumentId); + +public record class EmbeddingRequest(Stream Content, string FileName, string ContentType, Guid? DocumentId); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Workflows/GenerateEmbeddingExecutor.cs b/SqlDatabaseVectorSearch/Workflows/GenerateEmbeddingExecutor.cs new file mode 100644 index 0000000..a365855 --- /dev/null +++ b/SqlDatabaseVectorSearch/Workflows/GenerateEmbeddingExecutor.cs @@ -0,0 +1,41 @@ +using Microsoft.Agents.AI.Workflows; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Options; +using SqlDatabaseVectorSearch.ContentDecoders; +using SqlDatabaseVectorSearch.Models; +using SqlDatabaseVectorSearch.Services; +using SqlDatabaseVectorSearch.Settings; + +namespace SqlDatabaseVectorSearch.Workflows; + +public partial class GenerateEmbeddingExecutor(IServiceProvider serviceProvider, IEmbeddingGenerator> embeddingGenerator, TokenizerService tokenizerService, IOptions appSettingsOptions, ILogger logger) : Executor(nameof(GenerateEmbeddingExecutor)) +{ + private readonly AppSettings appSettings = appSettingsOptions.Value; + + [MessageHandler] + private async ValueTask HandleAsync(EmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken) + { + // Extract the contents of the file. + var decoder = serviceProvider.GetKeyedService(request.ContentType) ?? throw new NotSupportedException($"Content type '{request.ContentType}' is not supported."); + var chunks = await decoder.DecodeAsync(request.Content, request.ContentType, cancellationToken); + var chunkContents = chunks.Select(p => p.Content).ToList(); + + // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). + var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", chunkContents)); + + // Process paragraphs in batches. + var embeddings = new List>(); + foreach (var batch in chunkContents.Chunk(appSettings.EmbeddingBatchSize)) + { + logger.LogDebug("Processing batch of {Count} chunks for embedding generation...", batch.Length); + + // Generate embeddings for this batch. + var batchEmbeddings = await embeddingGenerator.GenerateAsync(batch, cancellationToken: cancellationToken); + embeddings.AddRange(batchEmbeddings); + } + + return new EmbeddingResponse(request, chunks, embeddings, tokenCount); + } +} + +public record class EmbeddingResponse(EmbeddingRequest Request, IEnumerable Chunks, IEnumerable> Embeddings, int TokenCount); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Workflows/StoreEmbeddingExecutor.cs b/SqlDatabaseVectorSearch/Workflows/StoreEmbeddingExecutor.cs new file mode 100644 index 0000000..c0f0155 --- /dev/null +++ b/SqlDatabaseVectorSearch/Workflows/StoreEmbeddingExecutor.cs @@ -0,0 +1,58 @@ +using Microsoft.Agents.AI.Workflows; +using Microsoft.Data.SqlTypes; +using Microsoft.EntityFrameworkCore; +using SqlDatabaseVectorSearch.Data; +using SqlDatabaseVectorSearch.Services; +using Entities = SqlDatabaseVectorSearch.Data.Entities; + +namespace SqlDatabaseVectorSearch.Workflows; + +public partial class StoreEmbeddingExecutor(ApplicationDbContext dbContext, DocumentService documentService, TokenizerService tokenizerService, TimeProvider timeProvider, ILogger logger) : Executor(nameof(StoreEmbeddingExecutor)) +{ + [MessageHandler] + private async ValueTask HandleAsync(EmbeddingResponse embeddingData, IWorkflowContext context, CancellationToken cancellationToken) + { + var strategy = dbContext.Database.CreateExecutionStrategy(); + var document = await strategy.ExecuteAsync(async (cancellationToken) => + { + await dbContext.Database.BeginTransactionAsync(cancellationToken); + + if (embeddingData.Request.DocumentId.HasValue) + { + // If the user is importing a document that already exists, delete the previous one. + await documentService.DeleteAsync(embeddingData.Request.DocumentId.Value, cancellationToken); + } + + var document = new Entities.Document { Id = embeddingData.Request.DocumentId.GetValueOrDefault(), Name = embeddingData.Request.FileName, CreationDate = timeProvider.GetUtcNow() }; + dbContext.Documents.Add(document); + + // Save the document chunks and the corresponding embedding in the database. + foreach (var (index, embedding) in embeddingData.Embeddings.Index()) + { + var chunk = embeddingData.Chunks.ElementAt(index); + logger.LogDebug("Storing a chunk of {TokenCount} tokens.", tokenizerService.CountEmbeddingTokens(chunk.Content)); + + var documentChunk = new Entities.DocumentChunk + { + Document = document, + Index = index, + PageNumber = chunk.PageNumber, + IndexOnPage = chunk.IndexOnPage, + Content = chunk.Content, + Embedding = new SqlVector(embedding.Vector) + }; + + dbContext.DocumentChunks.Add(documentChunk); + } + + await dbContext.SaveChangesAsync(cancellationToken); + await dbContext.Database.CommitTransactionAsync(cancellationToken); + + return document; + }, cancellationToken); + + return new(document.Id, embeddingData.TokenCount); + } +} + +public record class StoreEmbeddingResponse(Guid DocumentId, int TokenCount); \ No newline at end of file