mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a761ddcd7 | |||
| 14f983307e | |||
| 5624f73640 |
@@ -177,7 +177,7 @@ else
|
|||||||
var vectorSearchService = scope.ServiceProvider.GetRequiredService<VectorSearchService>();
|
var vectorSearchService = scope.ServiceProvider.GetRequiredService<VectorSearchService>();
|
||||||
|
|
||||||
var documentId = string.IsNullOrWhiteSpace(Model.DocumentId) ? null : (Guid?)Guid.Parse(Model.DocumentId);
|
var documentId = string.IsNullOrWhiteSpace(Model.DocumentId) ? null : (Guid?)Guid.Parse(Model.DocumentId);
|
||||||
await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId);
|
//await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId);
|
||||||
|
|
||||||
ToastService.Notify(await CreateToastMessageAsync(ToastType.Success, "Upload document", $"The document {fileName} has been successfully uploaded and indexed."));
|
ToastService.Notify(await CreateToastMessageAsync(ToastType.Success, "Upload document", $"The document {fileName} has been successfully uploaded and indexed."));
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
|
|
||||||
using System.ComponentModel;
|
using System.ComponentModel;
|
||||||
using Microsoft.AspNetCore.Http.HttpResults;
|
using Microsoft.AspNetCore.Http.HttpResults;
|
||||||
using MimeMapping;
|
|
||||||
using SqlDatabaseVectorSearch.Models;
|
using SqlDatabaseVectorSearch.Models;
|
||||||
using SqlDatabaseVectorSearch.Services;
|
using SqlDatabaseVectorSearch.Services;
|
||||||
|
using SqlDatabaseVectorSearch.Workflows;
|
||||||
|
|
||||||
namespace SqlDatabaseVectorSearch.Endpoints;
|
namespace SqlDatabaseVectorSearch.Endpoints;
|
||||||
|
|
||||||
@@ -23,12 +23,8 @@ public class DocumentEndpoints : IEndpointRouteHandlerBuilder
|
|||||||
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
||||||
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
|
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
|
||||||
{
|
{
|
||||||
using var stream = file.OpenReadStream();
|
var result = await vectorSearchService.ImportAsync(new FormFileEmbeddingRequest(file, documentId), cancellationToken);
|
||||||
|
return TypedResults.Ok(result);
|
||||||
// Note: file.ContentType is not 100% reliable (for example, for markdown file).
|
|
||||||
var response = await vectorSearchService.ImportAsync(stream, file.FileName, MimeUtility.GetMimeMapping(file.FileName), documentId, cancellationToken);
|
|
||||||
|
|
||||||
return TypedResults.Ok(response);
|
|
||||||
})
|
})
|
||||||
.DisableAntiforgery()
|
.DisableAntiforgery()
|
||||||
.ProducesProblem(StatusCodes.Status400BadRequest)
|
.ProducesProblem(StatusCodes.Status400BadRequest)
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
namespace SqlDatabaseVectorSearch.Models;
|
|
||||||
|
|
||||||
public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount);
|
|
||||||
@@ -1,8 +1,14 @@
|
|||||||
|
using System.ClientModel;
|
||||||
using System.Net.Mime;
|
using System.Net.Mime;
|
||||||
using System.Text.Json.Serialization;
|
using System.Text.Json.Serialization;
|
||||||
using FluentValidation;
|
using FluentValidation;
|
||||||
|
using Microsoft.Agents.AI.Hosting;
|
||||||
|
using Microsoft.Agents.AI.Workflows;
|
||||||
using Microsoft.EntityFrameworkCore;
|
using Microsoft.EntityFrameworkCore;
|
||||||
|
using Microsoft.Extensions.AI;
|
||||||
using Microsoft.SemanticKernel;
|
using Microsoft.SemanticKernel;
|
||||||
|
using OpenAI;
|
||||||
|
using OpenAI.Responses;
|
||||||
using SqlDatabaseVectorSearch.Components;
|
using SqlDatabaseVectorSearch.Components;
|
||||||
using SqlDatabaseVectorSearch.ContentDecoders;
|
using SqlDatabaseVectorSearch.ContentDecoders;
|
||||||
using SqlDatabaseVectorSearch.Data;
|
using SqlDatabaseVectorSearch.Data;
|
||||||
@@ -10,6 +16,7 @@ using SqlDatabaseVectorSearch.Extensions;
|
|||||||
using SqlDatabaseVectorSearch.Services;
|
using SqlDatabaseVectorSearch.Services;
|
||||||
using SqlDatabaseVectorSearch.Settings;
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
using SqlDatabaseVectorSearch.TextChunkers;
|
using SqlDatabaseVectorSearch.TextChunkers;
|
||||||
|
using SqlDatabaseVectorSearch.Workflows;
|
||||||
using TinyHelpers.AspNetCore.Extensions;
|
using TinyHelpers.AspNetCore.Extensions;
|
||||||
using TinyHelpers.AspNetCore.OpenApi;
|
using TinyHelpers.AspNetCore.OpenApi;
|
||||||
|
|
||||||
@@ -54,10 +61,29 @@ builder.Services.ConfigureHttpClientDefaults(configure =>
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
builder.Services.AddSingleton(_ =>
|
||||||
|
{
|
||||||
|
var embeddingClient = new OpenAIClient(new ApiKeyCredential(aiSettings.Embedding.ApiKey), new()
|
||||||
|
{
|
||||||
|
Endpoint = new(aiSettings.Embedding.Endpoint),
|
||||||
|
}).GetEmbeddingClient(aiSettings.Embedding.Deployment).AsIEmbeddingGenerator(aiSettings.Embedding.Dimensions);
|
||||||
|
|
||||||
|
return embeddingClient;
|
||||||
|
});
|
||||||
|
|
||||||
|
builder.Services.AddChatClient(_ =>
|
||||||
|
{
|
||||||
|
var chatClient = new OpenAIClient(new ApiKeyCredential(aiSettings.ChatCompletion.ApiKey), new()
|
||||||
|
{
|
||||||
|
Endpoint = new(aiSettings.ChatCompletion.Endpoint),
|
||||||
|
}).GetResponsesClient().AsIChatClientWithStoredOutputDisabled(aiSettings.ChatCompletion.Deployment);
|
||||||
|
|
||||||
|
return chatClient;
|
||||||
|
});
|
||||||
|
|
||||||
// Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions,
|
// Semantic Kernel is used to reformulate questions taking into account all the previous interactions,
|
||||||
// so that embeddings themselves can be generated more accurately.
|
// so that embeddings themselves can be generated more accurately.
|
||||||
builder.Services.AddKernel()
|
builder.Services.AddKernel()
|
||||||
.AddAzureOpenAIEmbeddingGenerator(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, modelId: aiSettings.Embedding.ModelId, dimensions: aiSettings.Embedding.Dimensions)
|
|
||||||
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey, modelId: aiSettings.ChatCompletion.ModelId);
|
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey, modelId: aiSettings.ChatCompletion.ModelId);
|
||||||
|
|
||||||
builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaTypeNames.Application.Pdf);
|
builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaTypeNames.Application.Pdf);
|
||||||
@@ -74,6 +100,25 @@ builder.Services.AddSingleton<ChatService>();
|
|||||||
builder.Services.AddScoped<DocumentService>();
|
builder.Services.AddScoped<DocumentService>();
|
||||||
builder.Services.AddScoped<VectorSearchService>();
|
builder.Services.AddScoped<VectorSearchService>();
|
||||||
|
|
||||||
|
builder.Services.AddSingleton<FormFileToEmbeddingRequestExecutor>();
|
||||||
|
builder.Services.AddSingleton<GenerateEmbeddingExecutor>();
|
||||||
|
builder.Services.AddScoped<StoreEmbeddingExecutor>(); // This executor is registered as scoped because it uses the DbContext, which is also scoped.
|
||||||
|
|
||||||
|
builder.AddWorkflow("EmbeddingWorkflow", (services, key) =>
|
||||||
|
{
|
||||||
|
var formfileToConversionRequestExecutor = services.GetRequiredService<FormFileToEmbeddingRequestExecutor>();
|
||||||
|
var generateEmbeddingExecutor = services.GetRequiredService<GenerateEmbeddingExecutor>();
|
||||||
|
var storeEmbeddingExecutor = services.GetRequiredService<StoreEmbeddingExecutor>();
|
||||||
|
|
||||||
|
var workflow = new WorkflowBuilder(formfileToConversionRequestExecutor).WithName(key)
|
||||||
|
.AddEdge(formfileToConversionRequestExecutor, generateEmbeddingExecutor)
|
||||||
|
.AddEdge(generateEmbeddingExecutor, storeEmbeddingExecutor)
|
||||||
|
.WithOutputFrom(storeEmbeddingExecutor)
|
||||||
|
.Build(validateOrphans: true);
|
||||||
|
|
||||||
|
return workflow;
|
||||||
|
}, ServiceLifetime.Scoped);
|
||||||
|
|
||||||
builder.Services.AddOpenApi(options =>
|
builder.Services.AddOpenApi(options =>
|
||||||
{
|
{
|
||||||
options.RemoveServerList();
|
options.RemoveServerList();
|
||||||
|
|||||||
@@ -2,84 +2,37 @@
|
|||||||
using System.Runtime.CompilerServices;
|
using System.Runtime.CompilerServices;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Text.RegularExpressions;
|
using System.Text.RegularExpressions;
|
||||||
|
using Microsoft.Agents.AI.Workflows;
|
||||||
using Microsoft.Data.SqlTypes;
|
using Microsoft.Data.SqlTypes;
|
||||||
using Microsoft.EntityFrameworkCore;
|
using Microsoft.EntityFrameworkCore;
|
||||||
using Microsoft.Extensions.AI;
|
using Microsoft.Extensions.AI;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using SqlDatabaseVectorSearch.ContentDecoders;
|
|
||||||
using SqlDatabaseVectorSearch.Data;
|
using SqlDatabaseVectorSearch.Data;
|
||||||
using SqlDatabaseVectorSearch.Models;
|
using SqlDatabaseVectorSearch.Models;
|
||||||
using SqlDatabaseVectorSearch.Settings;
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
|
using SqlDatabaseVectorSearch.Workflows;
|
||||||
using ChatResponse = SqlDatabaseVectorSearch.Models.ChatResponse;
|
using ChatResponse = SqlDatabaseVectorSearch.Models.ChatResponse;
|
||||||
using Entities = SqlDatabaseVectorSearch.Data.Entities;
|
using Entities = SqlDatabaseVectorSearch.Data.Entities;
|
||||||
|
|
||||||
namespace SqlDatabaseVectorSearch.Services;
|
namespace SqlDatabaseVectorSearch.Services;
|
||||||
|
|
||||||
public partial class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, DocumentService documentService, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
|
public partial class VectorSearchService([FromKeyedServices("EmbeddingWorkflow")] Workflow workflow, ApplicationDbContext dbContext, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
|
||||||
{
|
{
|
||||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||||
|
|
||||||
/// <summary>
/// Imports a document by running the injected embedding workflow on <paramref name="request"/>
/// and returning the <see cref="StoreEmbeddingResponse"/> the workflow emits as its output.
/// </summary>
/// <param name="request">The uploaded file (and optional document id) to process.</param>
/// <param name="cancellationToken">Token forwarded to the workflow run.</param>
/// <returns>The first <see cref="StoreEmbeddingResponse"/> found among the workflow output events.</returns>
/// <exception cref="InvalidOperationException">The workflow finished without error but produced no <see cref="StoreEmbeddingResponse"/>.</exception>
/// <remarks>If the run reports a <see cref="WorkflowErrorEvent"/> carrying an exception, that exception is rethrown to the caller.</remarks>
public async Task<StoreEmbeddingResponse> ImportAsync(FormFileEmbeddingRequest request, CancellationToken cancellationToken = default)
{
    await using var run = await InProcessExecution.RunAsync(workflow, request, cancellationToken: cancellationToken);

    // Materialize once: the events are inspected twice below (errors first, then outputs).
    var events = run.NewEvents.ToList();

    var exception = events.OfType<WorkflowErrorEvent>().Select(e => e.Exception).FirstOrDefault();
    if (exception is not null)
    {
        // A plain "throw exception;" would overwrite the stack trace recorded where the failure actually
        // happened inside the workflow. ExceptionDispatchInfo rethrows the same instance and keeps that trace.
        System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(exception).Throw();
    }

    var outputs = events.OfType<WorkflowOutputEvent>().Select(e => e.Data).OfType<StoreEmbeddingResponse>().ToList();
    if (outputs.Count == 0)
    {
        // Same exception type that First() would have raised, but with a message that names the actual problem.
        throw new InvalidOperationException("The embedding workflow completed without producing a StoreEmbeddingResponse.");
    }

    return outputs[0];
}
|
||||||
|
|
||||||
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
|
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
<TargetFramework>net10.0</TargetFramework>
|
<TargetFramework>net10.0</TargetFramework>
|
||||||
<ImplicitUsings>enable</ImplicitUsings>
|
<ImplicitUsings>enable</ImplicitUsings>
|
||||||
<Nullable>enable</Nullable>
|
<Nullable>enable</Nullable>
|
||||||
<NoWarn>$(NoWarn);SKEXP0010;SKEXP0050</NoWarn>
|
<NoWarn>$(NoWarn);SKEXP0010;SKEXP0050;OPENAI001;MAAI001</NoWarn>
|
||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
@@ -12,7 +12,18 @@
|
|||||||
<PackageReference Include="DocumentFormat.OpenXml" Version="3.5.1" />
|
<PackageReference Include="DocumentFormat.OpenXml" Version="3.5.1" />
|
||||||
<PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="10.0.1" />
|
<PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="10.0.1" />
|
||||||
<PackageReference Include="FluentValidation.DependencyInjectionExtensions" Version="12.1.1" />
|
<PackageReference Include="FluentValidation.DependencyInjectionExtensions" Version="12.1.1" />
|
||||||
|
<PackageReference Include="Microsoft.Agents.AI.Hosting" Version="1.10.0-preview.260610.1" />
|
||||||
|
<PackageReference Include="Microsoft.Agents.AI.OpenAI" Version="1.10.0" />
|
||||||
|
<PackageReference Include="Microsoft.Agents.AI.Workflows" Version="1.10.0" />
|
||||||
|
<PackageReference Include="Microsoft.Agents.AI.Workflows.Generators" Version="1.10.0">
|
||||||
|
<PrivateAssets>all</PrivateAssets>
|
||||||
|
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||||
|
</PackageReference>
|
||||||
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="10.0.9" />
|
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="10.0.9" />
|
||||||
|
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="10.0.9">
|
||||||
|
<PrivateAssets>all</PrivateAssets>
|
||||||
|
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||||
|
</PackageReference>
|
||||||
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.9" />
|
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.9" />
|
||||||
<PackageReference Include="Microsoft.EntityFrameworkCore.Tools" Version="10.0.9">
|
<PackageReference Include="Microsoft.EntityFrameworkCore.Tools" Version="10.0.9">
|
||||||
<PrivateAssets>all</PrivateAssets>
|
<PrivateAssets>all</PrivateAssets>
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Microsoft.SemanticKernel.Text;
|
|
||||||
using SqlDatabaseVectorSearch.Services;
|
using SqlDatabaseVectorSearch.Services;
|
||||||
using SqlDatabaseVectorSearch.Settings;
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
|
using SqlDatabaseVectorSearch.TextChunkers.Implementations;
|
||||||
|
|
||||||
namespace SqlDatabaseVectorSearch.TextChunkers;
|
namespace SqlDatabaseVectorSearch.TextChunkers;
|
||||||
|
|
||||||
@@ -11,8 +11,8 @@ public class DefaultTextChunker(TokenizerService tokenizerService, IOptions<AppS
|
|||||||
|
|
||||||
/// <summary>
/// Splits <paramref name="text"/> into paragraphs: the text is first cut into lines using
/// <c>MaxTokensPerLine</c>, then the lines are grouped using <c>MaxTokensPerParagraph</c> and
/// <c>OverlapTokens</c>. Tokens are counted with <c>tokenizerService.CountEmbeddingTokens</c>.
/// </summary>
/// <param name="text">The text to split.</param>
/// <returns>The resulting paragraphs.</returns>
public IList<string> Split(string text)
{
    return PlainTextChunker.SplitPlainTextParagraphs(
        PlainTextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens),
        appSettings.MaxTokensPerParagraph,
        appSettings.OverlapTokens,
        tokenCounter: tokenizerService.CountEmbeddingTokens);
}
|
||||||
|
|||||||
@@ -0,0 +1,372 @@
|
|||||||
|
using System.Diagnostics;
|
||||||
|
using System.Text;
|
||||||
|
|
||||||
|
namespace SqlDatabaseVectorSearch.TextChunkers.Implementations;
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Split text in chunks, attempting to leave meaning intact.
|
||||||
|
/// For plain text, split looking at new lines first, then periods, and so on.
|
||||||
|
/// For markdown, split looking at punctuation first, and so on.
|
||||||
|
/// </summary>
|
||||||
|
internal static class PlainTextChunker
|
||||||
|
{
|
||||||
|
/// <summary>
/// An ordered list of strings in which every entry is stored together with its token count,
/// so the count does not have to be computed again when the entry is revisited.
/// </summary>
private sealed class StringListWithTokenCount
{
    private readonly TokenCounter? tokenCounter;
    private readonly List<(string Value, int TokenCount)> entries = [];

    /// <param name="tokenCounter">Counter used by <see cref="Add(string)"/>; when null the default estimate is used.</param>
    public StringListWithTokenCount(TokenCounter? tokenCounter)
    {
        this.tokenCounter = tokenCounter;
    }

    /// <summary>Number of entries currently stored.</summary>
    public int Count => entries.Count;

    /// <summary>Returns the text stored at position <paramref name="i"/>.</summary>
    public string ValueAt(int i) => entries[i].Value;

    /// <summary>Returns the token count stored at position <paramref name="i"/>.</summary>
    public int TokenCountAt(int i) => entries[i].TokenCount;

    /// <summary>Appends <paramref name="value"/>, computing its token count with the configured counter or the default estimate.</summary>
    public void Add(string value)
    {
        var tokens = tokenCounter is not null ? tokenCounter(value) : GetDefaultTokenCount(value.Length);
        entries.Add((value, tokens));
    }

    /// <summary>Appends <paramref name="value"/> with an already known <paramref name="tokenCount"/>.</summary>
    public void Add(string value, int tokenCount)
    {
        entries.Add((value, tokenCount));
    }

    /// <summary>Appends every entry of <paramref name="range"/>, reusing its stored token counts.</summary>
    public void AddRange(StringListWithTokenCount range)
    {
        entries.AddRange(range.entries);
    }

    /// <summary>Removes <paramref name="count"/> entries starting at <paramref name="index"/>.</summary>
    public void RemoveRange(int index, int count)
    {
        entries.RemoveRange(index, count);
    }

    /// <summary>Returns a new list holding only the stored strings, in order.</summary>
    public List<string> ToStringList()
    {
        return entries.ConvertAll(entry => entry.Value);
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Signature of a function that reports how many tokens a string contains.
/// </summary>
/// <param name="input">The text to measure.</param>
/// <returns>The number of tokens in <paramref name="input"/>.</returns>
public delegate int TokenCounter(string input);

// Separator sets, tried in order, used to cut text that exceeds its token limit. Each string is a set of
// candidate separator characters; the trailing null entry stands for "no separator".
// Both tables spell the CJK/fullwidth punctuation as escapes: U+3002 ideographic full stop,
// U+FF0E fullwidth full stop, U+FF0C fullwidth comma, U+3001 ideographic comma. The plain-text table
// previously repeated the ASCII '.' and ',' in place of the fullwidth forms, so it never split on them.
private static readonly string?[] plainTextSplitOptions = ["\n", ".\u3002\uFF0E", "?!", ";", ":", ",\uFF0C\u3001", ")]}", " ", "-", null];
private static readonly string?[] markdownSplitOptions = [".\u3002\uFF0E", "?!", ";", ":", ",\uFF0C\u3001", ")]}", " ", "-", "\n\r", null];
|
||||||
|
|
||||||
|
/// <summary>
/// Splits plain text into lines using the plain-text separator table,
/// with <paramref name="maxTokensPerLine"/> as the per-line token limit. Trimming of the produced lines is requested.
/// </summary>
/// <param name="text">The text to split. Must not be null.</param>
/// <param name="maxTokensPerLine">Token limit for a single line (checked by <c>ValidateMaxTokens</c>).</param>
/// <param name="tokenCounter">Optional token counting function; the default counter is used when omitted.</param>
/// <returns>The resulting lines.</returns>
/// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
public static List<string> SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
{
    ArgumentNullException.ThrowIfNull(text);
    ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));

    var lines = InternalSplitLines(text, maxTokensPerLine, trim: true, splitOptions: plainTextSplitOptions, tokenCounter: tokenCounter);
    return lines;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Splits markdown text into lines using the markdown separator table,
/// with <paramref name="maxTokensPerLine"/> as the per-line token limit. Trimming of the produced lines is requested.
/// </summary>
/// <param name="text">The markdown text to split. Must not be null.</param>
/// <param name="maxTokensPerLine">Token limit for a single line (checked by <c>ValidateMaxTokens</c>).</param>
/// <param name="tokenCounter">Optional token counting function; the default counter is used when omitted.</param>
/// <returns>The resulting lines.</returns>
/// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
public static List<string> SplitMarkdownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
{
    ArgumentNullException.ThrowIfNull(text);
    ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));

    var lines = InternalSplitLines(text, maxTokensPerLine, trim: true, splitOptions: markdownSplitOptions, tokenCounter: tokenCounter);
    return lines;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Groups plain-text lines into paragraphs. Lines that are too long are split again
/// with the plain-text separator table (without trimming).
/// </summary>
/// <param name="lines">The lines to group.</param>
/// <param name="maxTokensPerParagraph">Token limit for a single paragraph.</param>
/// <param name="overlapTokens">Number of tokens taken from the start of the following paragraph and appended to each paragraph.</param>
/// <param name="chunkHeader">Optional text placed at the start of every produced paragraph.</param>
/// <param name="tokenCounter">Optional token counting function; the default counter is used when omitted.</param>
/// <returns>The resulting paragraphs.</returns>
public static List<string> SplitPlainTextParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null)
{
    return InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader, SplitLongLine, tokenCounter);

    static List<string> SplitLongLine(string text, int maxTokens, TokenCounter? counter)
    {
        return InternalSplitLines(text, maxTokens, trim: false, plainTextSplitOptions, counter);
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Groups markdown lines into paragraphs. Lines that are too long are split again
/// with the markdown separator table (without trimming).
/// </summary>
/// <param name="lines">The lines to group.</param>
/// <param name="maxTokensPerParagraph">Token limit for a single paragraph.</param>
/// <param name="overlapTokens">Number of tokens taken from the start of the following paragraph and appended to each paragraph.</param>
/// <param name="chunkHeader">Optional text placed at the start of every produced paragraph.</param>
/// <param name="tokenCounter">Optional token counting function; the default counter is used when omitted.</param>
/// <returns>The resulting paragraphs.</returns>
public static List<string> SplitMarkdownParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null)
{
    return InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader, SplitLongLine, tokenCounter);

    static List<string> SplitLongLine(string text, int maxTokens, TokenCounter? counter)
    {
        return InternalSplitLines(text, maxTokens, trim: false, markdownSplitOptions, counter);
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Shared implementation behind the public paragraph splitters: validates the arguments, works out how many
/// tokens remain for content once the overlap and the header are reserved, splits long lines, builds the
/// paragraphs and finally applies header and overlap.
/// </summary>
/// <param name="lines">The lines to group. Must not be null.</param>
/// <param name="maxTokensPerParagraph">Token limit for a single paragraph.</param>
/// <param name="overlapTokens">Tokens of overlap; must be non-negative and smaller than <paramref name="maxTokensPerParagraph"/>.</param>
/// <param name="chunkHeader">Optional text placed at the start of every paragraph.</param>
/// <param name="longLinesSplitter">Function used to split a text into pieces for a given token limit.</param>
/// <param name="tokenCounter">Optional token counting function.</param>
/// <returns>The processed paragraphs; an empty list when <paramref name="lines"/> is a collection known to be empty.</returns>
/// <exception cref="ArgumentNullException"><paramref name="lines"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="overlapTokens"/> is negative.</exception>
/// <exception cref="ArgumentException">The overlap, alone or together with the header, leaves no room for content.</exception>
private static List<string> InternalSplitTextParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func<string, int, TokenCounter?, List<string>> longLinesSplitter, TokenCounter? tokenCounter)
{
    ArgumentNullException.ThrowIfNull(lines);
    ValidateMaxTokens(maxTokensPerParagraph, nameof(maxTokensPerParagraph));

    if (overlapTokens < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(overlapTokens), "overlapTokens cannot be negative.");
    }

    if (overlapTokens >= maxTokensPerParagraph)
    {
        throw new ArgumentException("overlapTokens cannot be larger than or equal to maxTokensPerParagraph.", nameof(overlapTokens));
    }

    // Cheap early exit, only taken when emptiness can be known without enumerating.
    if (lines is ICollection<string> { Count: 0 })
    {
        return [];
    }

    var headerTokens = string.IsNullOrEmpty(chunkHeader) ? 0 : GetTokenCount(chunkHeader, tokenCounter);
    var contentBudget = maxTokensPerParagraph - overlapTokens - headerTokens;
    if (contentBudget <= 0)
    {
        throw new ArgumentException("chunkHeader and overlapTokens must leave room for paragraph content.", nameof(chunkHeader));
    }

    // Long lines are split first; the sequence stays lazy and is consumed by BuildParagraph.
    var boundedLines =
        from line in lines
        from piece in longLinesSplitter(NormalizeLineEndings(line), contentBudget, tokenCounter)
        select piece;

    var paragraphs = BuildParagraph(boundedLines, contentBudget, tokenCounter);
    return ProcessParagraphs(paragraphs, contentBudget, overlapTokens, chunkHeader, longLinesSplitter, tokenCounter);
}
|
||||||
|
|
||||||
|
/// <summary>
/// Accumulates lines into paragraphs. Before a line is added to a non-empty buffer, the line's token count
/// plus one is compared with <paramref name="maxTokensPerParagraph"/>; if that alone is below the limit the
/// buffer's token count is added as well. When the total reaches the limit, the buffer is emitted (trimmed)
/// and the line starts a new paragraph.
/// </summary>
/// <param name="truncatedLines">Lines to accumulate, enumerated once.</param>
/// <param name="maxTokensPerParagraph">Token limit that triggers closing the current paragraph.</param>
/// <param name="tokenCounter">Optional token counting function; when null the default estimate based on length is used.</param>
/// <returns>The list of trimmed paragraphs.</returns>
private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, int maxTokensPerParagraph, TokenCounter? tokenCounter)
{
    var completed = new List<string>();
    var pending = new StringBuilder();

    foreach (var line in truncatedLines)
    {
        if (pending.Length != 0)
        {
            // Text of the buffer, materialized only if the custom counter needed it.
            string? snapshot = null;
            var tokens = GetTokenCount(line, tokenCounter) + 1;

            // The buffer is measured only when the line on its own does not already reach the limit.
            if (tokens < maxTokensPerParagraph)
            {
                if (tokenCounter is null)
                {
                    tokens += GetDefaultTokenCount(pending.Length);
                }
                else
                {
                    snapshot = pending.ToString();
                    tokens += tokenCounter(snapshot);
                }
            }

            if (tokens >= maxTokensPerParagraph)
            {
                // Close the current paragraph; the line goes into a fresh buffer below.
                completed.Add((snapshot ?? pending.ToString()).Trim());
                pending.Clear();
            }
        }

        pending.AppendLine(line);
    }

    // Whatever is left forms the last paragraph.
    if (pending.Length != 0)
    {
        completed.Add(pending.ToString().Trim());
    }

    return completed;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Finalizes the paragraphs. First, when the last paragraph has fewer tokens than a quarter of the limit
/// and merging it with the previous one stays within the limit, the two are joined with a space (the input
/// list is modified in place). Then each paragraph is emitted as: optional header, the paragraph text and,
/// when overlap is requested and a following paragraph exists, a space plus the first piece obtained by
/// splitting that following paragraph at <paramref name="overlapTokens"/>.
/// </summary>
/// <param name="paragraphs">Paragraphs to process; may be modified by the merge step.</param>
/// <param name="adjustedMaxTokensPerParagraph">Token limit available for paragraph content.</param>
/// <param name="overlapTokens">Token limit used when extracting the overlap from the next paragraph; 0 disables overlap.</param>
/// <param name="chunkHeader">Optional text placed at the start of every paragraph.</param>
/// <param name="longLinesSplitter">Function used to split the next paragraph to obtain the overlap.</param>
/// <param name="tokenCounter">Optional token counting function.</param>
/// <returns>A new list with the final paragraph texts.</returns>
private static List<string> ProcessParagraphs(List<string> paragraphs, int adjustedMaxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func<string, int, TokenCounter?, List<string>> longLinesSplitter, TokenCounter? tokenCounter)
{
    // Avoid ending with a very short paragraph when it fits into the previous one.
    var lastIndex = paragraphs.Count - 1;
    if (lastIndex >= 1 && GetTokenCount(paragraphs[lastIndex], tokenCounter) < adjustedMaxTokensPerParagraph / 4)
    {
        var combined = $"{paragraphs[lastIndex - 1]} {paragraphs[lastIndex]}";
        if (GetTokenCount(combined, tokenCounter) <= adjustedMaxTokensPerParagraph)
        {
            paragraphs[lastIndex - 1] = combined;
            paragraphs.RemoveAt(lastIndex);
        }
    }

    var output = new List<string>();
    var builder = new StringBuilder();

    for (var i = 0; i < paragraphs.Count; i++)
    {
        builder.Clear();

        if (chunkHeader is not null)
        {
            builder.Append(chunkHeader);
        }

        builder.Append(paragraphs[i]);

        var hasFollowingParagraph = i + 1 < paragraphs.Count;
        if (overlapTokens > 0 && hasFollowingParagraph)
        {
            var overlapPieces = longLinesSplitter(paragraphs[i + 1], overlapTokens, tokenCounter);
            if (overlapPieces.Count > 0)
            {
                builder.Append(' ').Append(overlapPieces[0]);
            }
        }

        output.Add(builder.ToString());
    }

    return output;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Splits <paramref name="text"/> by applying the separator sets in <paramref name="splitOptions"/> one after
/// another. Each pass replaces the current pieces with the result of splitting them; the loop stops as soon
/// as a pass reports that nothing was split, or when the separator sets are exhausted.
/// </summary>
/// <param name="text">The text to split; its line endings are normalized first.</param>
/// <param name="maxTokensPerLine">Token limit passed to every split pass.</param>
/// <param name="trim">Trim flag forwarded to the split pass.</param>
/// <param name="splitOptions">Separator sets in the order they are tried; a null entry is passed on as an empty span.</param>
/// <param name="tokenCounter">Optional token counting function.</param>
/// <returns>The resulting pieces as plain strings.</returns>
private static List<string> InternalSplitLines(string text, int maxTokensPerLine, bool trim, string?[] splitOptions, TokenCounter? tokenCounter)
{
    var pieces = new StringListWithTokenCount(tokenCounter);
    pieces.Add(NormalizeLineEndings(text));

    foreach (var separators in splitOptions)
    {
        // Entries present before this pass; they are dropped once their replacements are appended.
        var previousCount = pieces.Count;
        var (refined, anySplit) = Split(pieces, maxTokensPerLine, separators.AsSpan(), trim, tokenCounter);

        pieces.AddRange(refined);
        pieces.RemoveRange(0, previousCount);

        if (!anySplit)
        {
            break;
        }
    }

    return pieces.ToStringList();
}
|
||||||
|
|
||||||
|
/// <summary>
/// Applies the span-based <c>Split</c> overload to every entry of <paramref name="input"/> and
/// gathers all produced pieces into a single list.
/// </summary>
/// <returns>
/// The collected pieces, and a flag that is <see langword="true"/> when the split of at least one
/// entry returned <see langword="true"/>.
/// </returns>
private static (StringListWithTokenCount, bool) Split(StringListWithTokenCount input, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter? tokenCounter)
{
    StringListWithTokenCount collected = new(tokenCounter);
    var anyEntrySplit = false;

    // The entry count is read once and used as the loop bound.
    var entryCount = input.Count;
    for (var entryIndex = 0; entryIndex < entryCount; entryIndex++)
    {
        var entry = input.ValueAt(entryIndex);
        var (pieces, entryWasSplit) = Split(entry.AsSpan(), entry, maxTokens, separators, trim, tokenCounter, input.TokenCountAt(entryIndex));

        collected.AddRange(pieces);
        if (entryWasSplit)
        {
            anyEntrySplit = true;
        }
    }

    return (collected, anyEntrySplit);
}
|
||||||
|
|
||||||
|
/// <summary>
/// Recursively splits <paramref name="input"/> when its token count exceeds <paramref name="maxTokens"/>.
/// </summary>
/// <param name="input">The text to split.</param>
/// <param name="inputString">The same text as a string when already available (avoids an allocation), otherwise <see langword="null"/>.</param>
/// <param name="maxTokens">Token limit for a single piece.</param>
/// <param name="separators">Characters at which a cut is allowed; when empty the text is cut at its midpoint.</param>
/// <param name="trim">Whether pieces are trimmed.</param>
/// <param name="tokenCounter">Optional custom token counter.</param>
/// <param name="inputTokenCount">Token count of <paramref name="input"/>, computed by the caller.</param>
/// <returns>
/// The produced pieces and a flag. When the text is cut, the flag is the logical OR of the flags of
/// both halves; otherwise it is <see langword="true"/> exactly when the text exceeds the limit.
/// </returns>
private static (StringListWithTokenCount, bool) Split(ReadOnlySpan<char> input, string? inputString, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter? tokenCounter, int inputTokenCount)
{
    Debug.Assert(inputString is null || input.SequenceEqual(inputString.AsSpan()));

    StringListWithTokenCount pieces = new(tokenCounter);
    var exceedsLimit = inputTokenCount > maxTokens;

    if (exceedsLimit)
    {
        var midpoint = input.Length / 2;
        var cutPoint = -1;

        if (separators.IsEmpty)
        {
            cutPoint = midpoint;
        }
        else if (input.Length > 2)
        {
            // Walk over every separator (the last character is excluded from the search) and
            // keep a cut position whenever the separator is closer to the midpoint than the
            // current cut position.
            var searchStart = 0;
            int relativeIndex;
            while ((relativeIndex = input[searchStart..^1].IndexOfAny(separators)) >= 0)
            {
                var separatorIndex = searchStart + relativeIndex;

                if (Math.Abs(midpoint - separatorIndex) < Math.Abs(midpoint - cutPoint))
                {
                    cutPoint = separatorIndex + 1;
                }

                searchStart = separatorIndex + 1;
            }
        }

        if (cutPoint > 0)
        {
            var head = input[..cutPoint];
            var tail = input[cutPoint..];
            if (trim)
            {
                head = head.Trim();
                tail = tail.Trim();
            }

            // Recurse into both halves, keeping the pieces in their original order.
            var (headPieces, headFlag) = Split(head, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(head, tokenCounter));
            pieces.AddRange(headPieces);

            var (tailPieces, tailFlag) = Split(tail, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(tail, tokenCounter));
            pieces.AddRange(tailPieces);

            return (pieces, headFlag || tailFlag);
        }
    }

    // No cut was made: emit the text as a single piece.
    var text = inputString ?? input.ToString();
    var tokenCount = inputTokenCount;

    if (trim)
    {
        var trimmed = text.Trim();
        if (!string.Equals(trimmed, text, StringComparison.Ordinal))
        {
            // Trimming changed the text, so the token count supplied by the caller no longer applies.
            text = trimmed;
            tokenCount = GetTokenCount(text, tokenCounter);
        }
    }

    pieces.Add(text, tokenCount);

    return (pieces, exceedsLimit);
}
|
||||||
|
|
||||||
|
/// <summary>
/// Counts the tokens of <paramref name="input"/> with <paramref name="tokenCounter"/>, or with the
/// length-based default estimate when no counter is supplied.
/// </summary>
private static int GetTokenCount(string input, TokenCounter? tokenCounter)
{
    if (tokenCounter is null)
    {
        return GetDefaultTokenCount(input.Length);
    }

    return tokenCounter(input);
}
|
||||||
|
|
||||||
|
/// <summary>
/// Counts the tokens of <paramref name="input"/> with <paramref name="tokenCounter"/>, or with the
/// length-based default estimate when no counter is supplied.
/// </summary>
private static int GetTokenCount(ReadOnlySpan<char> input, TokenCounter? tokenCounter)
{
    if (tokenCounter is null)
    {
        return GetDefaultTokenCount(input.Length);
    }

    // The counter takes a string, so the span is converted to one only on this path.
    return tokenCounter(input.ToString());
}
|
||||||
|
|
||||||
|
/// <summary>
/// Converts <c>\r\n</c> and lone <c>\r</c> line endings in <paramref name="text"/> to <c>\n</c>.
/// </summary>
private static string NormalizeLineEndings(string text)
{
    // Collapse CRLF pairs first so they do not turn into two line feeds in the second step.
    var withoutCrLf = text.Replace("\r\n", "\n");

    return withoutCrLf.Replace('\r', '\n');
}
|
||||||
|
|
||||||
|
/// <summary>
/// Ensures that <paramref name="maxTokens"/> is greater than zero.
/// </summary>
/// <param name="maxTokens">The value to validate.</param>
/// <param name="parameterName">Name reported as the offending parameter in the exception.</param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="maxTokens"/> is zero or negative.</exception>
private static void ValidateMaxTokens(int maxTokens, string parameterName)
{
    if (maxTokens > 0)
    {
        return;
    }

    throw new ArgumentOutOfRangeException(parameterName, "The maximum token count must be a positive number.");
}
|
||||||
|
|
||||||
|
/// <summary>
/// Estimates a token count from a text length when no token counter is available.
/// </summary>
/// <param name="length">Text length in characters; expected to be non-negative.</param>
/// <returns>Zero for empty text, otherwise a quarter of the length (rounded down) but at least one.</returns>
private static int GetDefaultTokenCount(int length)
{
    Debug.Assert(length >= 0);

    if (length == 0)
    {
        return 0;
    }

    // Shifting right by two divides by four; non-empty text always counts as at least one token.
    var estimate = length >> 2;

    return estimate < 1 ? 1 : estimate;
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Microsoft.SemanticKernel.Text;
|
|
||||||
using SqlDatabaseVectorSearch.Services;
|
using SqlDatabaseVectorSearch.Services;
|
||||||
using SqlDatabaseVectorSearch.Settings;
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
|
using SqlDatabaseVectorSearch.TextChunkers.Implementations;
|
||||||
|
|
||||||
namespace SqlDatabaseVectorSearch.TextChunkers;
|
namespace SqlDatabaseVectorSearch.TextChunkers;
|
||||||
|
|
||||||
@@ -11,8 +11,8 @@ public class MarkdownTextChunker(TokenizerService tokenizerService, IOptions<App
|
|||||||
|
|
||||||
public IList<string> Split(string text)
|
public IList<string> Split(string text)
|
||||||
{
|
{
|
||||||
var lines = TextChunker.SplitMarkDownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
|
var lines = PlainTextChunker.SplitMarkdownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
|
||||||
var paragraphs = TextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
|
var paragraphs = PlainTextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
|
||||||
|
|
||||||
return paragraphs;
|
return paragraphs;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
using Microsoft.Agents.AI.Workflows;
|
||||||
|
|
||||||
|
namespace SqlDatabaseVectorSearch.Workflows;
|
||||||
|
|
||||||
|
/// <summary>
/// Workflow executor that converts a <see cref="FormFileEmbeddingRequest"/> into an
/// <see cref="EmbeddingRequest"/>.
/// </summary>
public partial class FormFileToEmbeddingRequestExecutor() : Executor(nameof(FormFileToEmbeddingRequestExecutor))
{
    [MessageHandler]
    private ValueTask<EmbeddingRequest> HandleAsync(FormFileEmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken)
    {
        var file = request.File;

        var content = file.OpenReadStream();
        var fileName = Path.GetFileName(file.FileName);

        // The content type is derived from the file name because file.ContentType is not 100%
        // reliable (for example, for markdown files).
        var contentType = MimeMapping.MimeUtility.GetMimeMapping(file.FileName);

        return ValueTask.FromResult(new EmbeddingRequest(content, fileName, contentType, request.DocumentId));
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Input message of <see cref="FormFileToEmbeddingRequestExecutor"/>: an uploaded file to be imported.
/// </summary>
/// <param name="File">The uploaded file; its stream and file name are read by the executor.</param>
/// <param name="DocumentId">Optional document identifier, copied unchanged into the resulting <see cref="EmbeddingRequest"/>.</param>
public record class FormFileEmbeddingRequest(IFormFile File, Guid? DocumentId);
|
||||||
|
|
||||||
|
/// <summary>
/// Describes a document whose content has to be decoded and embedded.
/// </summary>
/// <param name="Content">Stream with the raw document content.</param>
/// <param name="FileName">File name of the document (without directory information when produced by <see cref="FormFileToEmbeddingRequestExecutor"/>).</param>
/// <param name="ContentType">MIME type of the document.</param>
/// <param name="DocumentId">Optional identifier of the document.</param>
public record class EmbeddingRequest(Stream Content, string FileName, string ContentType, Guid? DocumentId);
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
using Microsoft.Agents.AI.Workflows;
|
||||||
|
using Microsoft.Extensions.AI;
|
||||||
|
using Microsoft.Extensions.Options;
|
||||||
|
using SqlDatabaseVectorSearch.ContentDecoders;
|
||||||
|
using SqlDatabaseVectorSearch.Models;
|
||||||
|
using SqlDatabaseVectorSearch.Services;
|
||||||
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
|
|
||||||
|
namespace SqlDatabaseVectorSearch.Workflows;
|
||||||
|
|
||||||
|
/// <summary>
/// Workflow executor that decodes a document into chunks and generates an embedding for each chunk.
/// </summary>
public partial class GenerateEmbeddingExecutor(IServiceProvider serviceProvider, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger) : Executor(nameof(GenerateEmbeddingExecutor))
{
    private readonly AppSettings appSettings = appSettingsOptions.Value;

    [MessageHandler]
    private async ValueTask<EmbeddingResponse> HandleAsync(EmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken)
    {
        // The decoder is registered as a keyed service, with the content type as the key.
        var contentDecoder = serviceProvider.GetKeyedService<IContentDecoder>(request.ContentType);
        if (contentDecoder is null)
        {
            throw new NotSupportedException($"Content type '{request.ContentType}' is not supported.");
        }

        // Extract the contents of the file.
        var chunks = await contentDecoder.DecodeAsync(request.Content, request.ContentType, cancellationToken);
        var texts = chunks.Select(chunk => chunk.Content).ToList();

        // The token count of the whole document is the total number of tokens used for the
        // embeddings (it may be needed, for example, for cost analysis).
        var totalTokens = tokenizerService.CountEmbeddingTokens(string.Join(" ", texts));

        // Generate the embeddings batch by batch.
        var embeddings = new List<Embedding<float>>();
        foreach (var batch in texts.Chunk(appSettings.EmbeddingBatchSize))
        {
            logger.LogDebug("Processing batch of {Count} chunks for embedding generation...", batch.Length);

            embeddings.AddRange(await embeddingGenerator.GenerateAsync(batch, cancellationToken: cancellationToken));
        }

        return new EmbeddingResponse(request, chunks, embeddings, totalTokens);
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Result of the embedding generation for a document.
/// </summary>
/// <param name="Request">The request the embeddings were generated for.</param>
/// <param name="Chunks">The chunks decoded from the document.</param>
/// <param name="Embeddings">The generated embeddings, in the same order as <paramref name="Chunks"/>.</param>
/// <param name="TokenCount">Token count of all chunk contents joined with a single space.</param>
public record class EmbeddingResponse(EmbeddingRequest Request, IEnumerable<Chunk> Chunks, IEnumerable<Embedding<float>> Embeddings, int TokenCount);
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
using Microsoft.Agents.AI.Workflows;
|
||||||
|
using Microsoft.Data.SqlTypes;
|
||||||
|
using Microsoft.EntityFrameworkCore;
|
||||||
|
using SqlDatabaseVectorSearch.Data;
|
||||||
|
using SqlDatabaseVectorSearch.Services;
|
||||||
|
using Entities = SqlDatabaseVectorSearch.Data.Entities;
|
||||||
|
|
||||||
|
namespace SqlDatabaseVectorSearch.Workflows;
|
||||||
|
|
||||||
|
/// <summary>
/// Workflow executor that stores a document, its chunks and the corresponding embeddings in the database.
/// </summary>
public partial class StoreEmbeddingExecutor(ApplicationDbContext dbContext, DocumentService documentService, TokenizerService tokenizerService, TimeProvider timeProvider, ILogger<VectorSearchService> logger) : Executor(nameof(StoreEmbeddingExecutor))
{
    /// <summary>
    /// Saves the document described by <paramref name="embeddingData"/> inside a single transaction,
    /// replacing the existing document when a document identifier is supplied.
    /// </summary>
    /// <param name="embeddingData">The chunks and embeddings to store, together with the originating request.</param>
    /// <param name="context">The workflow context (not used by this handler).</param>
    /// <param name="cancellationToken">Token used to cancel the database operations.</param>
    /// <returns>The identifier of the stored document and the token count taken from <paramref name="embeddingData"/>.</returns>
    [MessageHandler]
    private async ValueTask<StoreEmbeddingResponse> HandleAsync(EmbeddingResponse embeddingData, IWorkflowContext context, CancellationToken cancellationToken)
    {
        // Chunks is exposed as IEnumerable: materialize it once so that the positional lookups
        // below do not enumerate the source again for every embedding.
        var chunks = embeddingData.Chunks.ToList();

        var strategy = dbContext.Database.CreateExecutionStrategy();
        var document = await strategy.ExecuteAsync(async (ct) =>
        {
            // Disposing the transaction rolls back uncommitted work when an exception escapes,
            // so a failed attempt does not leave a transaction open on the context.
            await using var transaction = await dbContext.Database.BeginTransactionAsync(ct);

            if (embeddingData.Request.DocumentId.HasValue)
            {
                // If the user is importing a document that already exists, delete the previous one.
                await documentService.DeleteAsync(embeddingData.Request.DocumentId.Value, ct);
            }

            var newDocument = new Entities.Document { Id = embeddingData.Request.DocumentId.GetValueOrDefault(), Name = embeddingData.Request.FileName, CreationDate = timeProvider.GetUtcNow() };
            dbContext.Documents.Add(newDocument);

            // Save the document chunks and the corresponding embeddings in the database.
            foreach (var (index, embedding) in embeddingData.Embeddings.Index())
            {
                var chunk = chunks[index];
                logger.LogDebug("Storing a chunk of {TokenCount} tokens.", tokenizerService.CountEmbeddingTokens(chunk.Content));

                var documentChunk = new Entities.DocumentChunk
                {
                    Document = newDocument,
                    Index = index,
                    PageNumber = chunk.PageNumber,
                    IndexOnPage = chunk.IndexOnPage,
                    Content = chunk.Content,
                    Embedding = new SqlVector<float>(embedding.Vector)
                };

                dbContext.DocumentChunks.Add(documentChunk);
            }

            await dbContext.SaveChangesAsync(ct);
            await transaction.CommitAsync(ct);

            return newDocument;
        }, cancellationToken);

        return new(document.Id, embeddingData.TokenCount);
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Result produced by <see cref="StoreEmbeddingExecutor"/> after a document has been stored.
/// </summary>
/// <param name="DocumentId">Identifier of the stored document.</param>
/// <param name="TokenCount">Token count carried over from the embedding response.</param>
public record class StoreEmbeddingResponse(Guid DocumentId, int TokenCount);
|
||||||
Reference in New Issue
Block a user