3 Commits

Author SHA1 Message Date
Marco Minerva 8a761ddcd7 Refactor document import to workflow-based architecture
Replaces the direct import logic with a workflow using Microsoft.Agents.AI.Workflows. Adds executors for file conversion, embedding generation, and storage. Updates dependency injection, API endpoint, and UI to use the new workflow. Replaces ImportDocumentResponse with StoreEmbeddingResponse. Adds required NuGet packages.
2026-06-17 12:35:30 +02:00
Marco Minerva 14f983307e Improve validation and naming in PlainTextChunker
Refactored PlainTextChunker to add input validation, null checks, and argument validation in public methods. Renamed SplitMarkDownLines to SplitMarkdownLines for consistency and updated all references. Centralized line ending normalization and token count validation into dedicated methods. Enhanced error handling with clearer messages and exception types. Simplified paragraph merging logic and ensured GetDefaultTokenCount returns at least 1 for non-empty input. Updated MarkdownTextChunker to use the corrected method name.
2026-06-16 15:55:11 +02:00
Marco Minerva 5624f73640 Refactor: replace SemanticKernel with Agents.AI.OpenAI
Removed Microsoft.SemanticKernel dependencies in favor of Microsoft.Agents.AI.OpenAI for embedding and chat services. Updated DI registrations in Program.cs to use OpenAIClient. Reimplemented text chunking with a new PlainTextChunker class, updating DefaultTextChunker and MarkdownTextChunker accordingly. Updated .csproj to add new package references and suppress related analyzer warnings.
2026-06-15 17:58:30 +02:00
12 changed files with 569 additions and 78 deletions
@@ -177,7 +177,7 @@ else
var vectorSearchService = scope.ServiceProvider.GetRequiredService<VectorSearchService>(); var vectorSearchService = scope.ServiceProvider.GetRequiredService<VectorSearchService>();
var documentId = string.IsNullOrWhiteSpace(Model.DocumentId) ? null : (Guid?)Guid.Parse(Model.DocumentId); var documentId = string.IsNullOrWhiteSpace(Model.DocumentId) ? null : (Guid?)Guid.Parse(Model.DocumentId);
await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId); //await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId);
ToastService.Notify(await CreateToastMessageAsync(ToastType.Success, "Upload document", $"The document {fileName} has been successfully uploaded and indexed.")); ToastService.Notify(await CreateToastMessageAsync(ToastType.Success, "Upload document", $"The document {fileName} has been successfully uploaded and indexed."));
@@ -1,9 +1,9 @@
using System.ComponentModel; using System.ComponentModel;
using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Http.HttpResults;
using MimeMapping;
using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Services;
using SqlDatabaseVectorSearch.Workflows;
namespace SqlDatabaseVectorSearch.Endpoints; namespace SqlDatabaseVectorSearch.Endpoints;
@@ -23,12 +23,8 @@ public class DocumentEndpoints : IEndpointRouteHandlerBuilder
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken, documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
{ {
using var stream = file.OpenReadStream(); var result = await vectorSearchService.ImportAsync(new FormFileEmbeddingRequest(file, documentId), cancellationToken);
return TypedResults.Ok(result);
// Note: file.ContentType is not 100% reliable (for example, for markdown file).
var response = await vectorSearchService.ImportAsync(stream, file.FileName, MimeUtility.GetMimeMapping(file.FileName), documentId, cancellationToken);
return TypedResults.Ok(response);
}) })
.DisableAntiforgery() .DisableAntiforgery()
.ProducesProblem(StatusCodes.Status400BadRequest) .ProducesProblem(StatusCodes.Status400BadRequest)
@@ -1,3 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount);
+46 -1
View File
@@ -1,8 +1,14 @@
using System.ClientModel;
using System.Net.Mime; using System.Net.Mime;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using FluentValidation; using FluentValidation;
using Microsoft.Agents.AI.Hosting;
using Microsoft.Agents.AI.Workflows;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel; using Microsoft.SemanticKernel;
using OpenAI;
using OpenAI.Responses;
using SqlDatabaseVectorSearch.Components; using SqlDatabaseVectorSearch.Components;
using SqlDatabaseVectorSearch.ContentDecoders; using SqlDatabaseVectorSearch.ContentDecoders;
using SqlDatabaseVectorSearch.Data; using SqlDatabaseVectorSearch.Data;
@@ -10,6 +16,7 @@ using SqlDatabaseVectorSearch.Extensions;
using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Services;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
using SqlDatabaseVectorSearch.TextChunkers; using SqlDatabaseVectorSearch.TextChunkers;
using SqlDatabaseVectorSearch.Workflows;
using TinyHelpers.AspNetCore.Extensions; using TinyHelpers.AspNetCore.Extensions;
using TinyHelpers.AspNetCore.OpenApi; using TinyHelpers.AspNetCore.OpenApi;
@@ -54,10 +61,29 @@ builder.Services.ConfigureHttpClientDefaults(configure =>
}); });
}); });
// Registers, as a singleton, the embedding generator obtained from an OpenAIClient
// that targets the endpoint, deployment and dimensions found in aiSettings.Embedding.
builder.Services.AddSingleton(_ =>
{
    var openAIClient = new OpenAIClient(
        new ApiKeyCredential(aiSettings.Embedding.ApiKey),
        new() { Endpoint = new(aiSettings.Embedding.Endpoint) });

    return openAIClient
        .GetEmbeddingClient(aiSettings.Embedding.Deployment)
        .AsIEmbeddingGenerator(aiSettings.Embedding.Dimensions);
});
// Registers the chat client: it is created from the Responses client of an OpenAIClient
// that targets the endpoint and deployment found in aiSettings.ChatCompletion.
builder.Services.AddChatClient(_ =>
{
    var openAIClient = new OpenAIClient(
        new ApiKeyCredential(aiSettings.ChatCompletion.ApiKey),
        new() { Endpoint = new(aiSettings.ChatCompletion.Endpoint) });

    return openAIClient
        .GetResponsesClient()
        .AsIChatClientWithStoredOutputDisabled(aiSettings.ChatCompletion.Deployment);
});
// Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions, // Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions,
// so that embeddings themselves can be generated more accurately. // so that embeddings themselves can be generated more accurately.
builder.Services.AddKernel() builder.Services.AddKernel()
.AddAzureOpenAIEmbeddingGenerator(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, modelId: aiSettings.Embedding.ModelId, dimensions: aiSettings.Embedding.Dimensions)
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey, modelId: aiSettings.ChatCompletion.ModelId); .AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey, modelId: aiSettings.ChatCompletion.ModelId);
builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaTypeNames.Application.Pdf); builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaTypeNames.Application.Pdf);
@@ -74,6 +100,25 @@ builder.Services.AddSingleton<ChatService>();
builder.Services.AddScoped<DocumentService>(); builder.Services.AddScoped<DocumentService>();
builder.Services.AddScoped<VectorSearchService>(); builder.Services.AddScoped<VectorSearchService>();
// Executors used by the embedding workflow.
builder.Services.AddSingleton<FormFileToEmbeddingRequestExecutor>();
builder.Services.AddSingleton<GenerateEmbeddingExecutor>();

// This executor is registered as scoped because it uses the DbContext, which is also scoped.
builder.Services.AddScoped<StoreEmbeddingExecutor>();

// Builds the "EmbeddingWorkflow" as a chain of three executors
// (form file conversion -> embedding generation -> storage); the workflow output comes from the storage step.
builder.AddWorkflow("EmbeddingWorkflow", (services, key) =>
{
    var conversionStep = services.GetRequiredService<FormFileToEmbeddingRequestExecutor>();
    var embeddingStep = services.GetRequiredService<GenerateEmbeddingExecutor>();
    var storageStep = services.GetRequiredService<StoreEmbeddingExecutor>();

    return new WorkflowBuilder(conversionStep).WithName(key)
        .AddEdge(conversionStep, embeddingStep)
        .AddEdge(embeddingStep, storageStep)
        .WithOutputFrom(storageStep)
        .Build(validateOrphans: true);
}, ServiceLifetime.Scoped);
builder.Services.AddOpenApi(options => builder.Services.AddOpenApi(options =>
{ {
options.RemoveServerList(); options.RemoveServerList();
@@ -2,84 +2,37 @@
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.RegularExpressions; using System.Text.RegularExpressions;
using Microsoft.Agents.AI.Workflows;
using Microsoft.Data.SqlTypes; using Microsoft.Data.SqlTypes;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.AI; using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using SqlDatabaseVectorSearch.ContentDecoders;
using SqlDatabaseVectorSearch.Data; using SqlDatabaseVectorSearch.Data;
using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
using SqlDatabaseVectorSearch.Workflows;
using ChatResponse = SqlDatabaseVectorSearch.Models.ChatResponse; using ChatResponse = SqlDatabaseVectorSearch.Models.ChatResponse;
using Entities = SqlDatabaseVectorSearch.Data.Entities; using Entities = SqlDatabaseVectorSearch.Data.Entities;
namespace SqlDatabaseVectorSearch.Services; namespace SqlDatabaseVectorSearch.Services;
public partial class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, DocumentService documentService, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger) public partial class VectorSearchService([FromKeyedServices("EmbeddingWorkflow")] Workflow workflow, ApplicationDbContext dbContext, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
{ {
private readonly AppSettings appSettings = appSettingsOptions.Value; private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId, CancellationToken cancellationToken = default) public async Task<StoreEmbeddingResponse> ImportAsync(FormFileEmbeddingRequest request, CancellationToken cancellationToken = default)
{ {
// Extract the contents of the file. await using var run = await InProcessExecution.RunAsync(workflow, request, cancellationToken: cancellationToken);
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported."); var events = run.NewEvents.ToList();
var chunks = await decoder.DecodeAsync(stream, contentType, cancellationToken);
var chunkContents = chunks.Select(p => p.Content).ToList();
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). var exception = events.OfType<WorkflowErrorEvent>().Select(e => e.Exception).FirstOrDefault();
var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", chunkContents)); if (exception is not null)
var strategy = dbContext.Database.CreateExecutionStrategy();
var document = await strategy.ExecuteAsync(async (cancellationToken) =>
{ {
await dbContext.Database.BeginTransactionAsync(cancellationToken); throw exception;
if (documentId.HasValue)
{
// If the user is importing a document that already exists, delete the previous one.
await documentService.DeleteAsync(documentId.Value, cancellationToken);
} }
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; var result = events.OfType<WorkflowOutputEvent>().Select(e => e.Data).OfType<StoreEmbeddingResponse>().First();
dbContext.Documents.Add(document); return result;
// Process paragraphs in batches.
var embeddings = new List<Embedding<float>>();
foreach (var batch in chunkContents.Chunk(appSettings.EmbeddingBatchSize))
{
logger.LogDebug("Processing batch of {Count} chunks for embedding generation...", batch.Length);
// Generate embeddings for this batch.
var batchEmbeddings = await embeddingGenerator.GenerateAsync(batch, cancellationToken: cancellationToken);
embeddings.AddRange(batchEmbeddings);
}
// Save the document chunks and the corresponding embedding in the database.
foreach (var (index, embedding) in embeddings.Index())
{
var chunk = chunks.ElementAt(index);
logger.LogDebug("Storing a chunk of {TokenCount} tokens.", tokenizerService.CountEmbeddingTokens(chunk.Content));
var documentChunk = new Entities.DocumentChunk
{
Document = document,
Index = index,
PageNumber = chunk.PageNumber,
IndexOnPage = chunk.IndexOnPage,
Content = chunk.Content,
Embedding = new SqlVector<float>(embedding.Vector)
};
dbContext.DocumentChunks.Add(documentChunk);
}
await dbContext.SaveChangesAsync(cancellationToken);
await dbContext.Database.CommitTransactionAsync(cancellationToken);
return document;
}, cancellationToken);
return new(document.Id, tokenCount);
} }
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default) public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
@@ -4,7 +4,7 @@
<TargetFramework>net10.0</TargetFramework> <TargetFramework>net10.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings> <ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable> <Nullable>enable</Nullable>
<NoWarn>$(NoWarn);SKEXP0010;SKEXP0050</NoWarn> <NoWarn>$(NoWarn);SKEXP0010;SKEXP0050;OPENAI001;MAAI001</NoWarn>
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
@@ -12,7 +12,18 @@
<PackageReference Include="DocumentFormat.OpenXml" Version="3.5.1" /> <PackageReference Include="DocumentFormat.OpenXml" Version="3.5.1" />
<PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="10.0.1" /> <PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="10.0.1" />
<PackageReference Include="FluentValidation.DependencyInjectionExtensions" Version="12.1.1" /> <PackageReference Include="FluentValidation.DependencyInjectionExtensions" Version="12.1.1" />
<PackageReference Include="Microsoft.Agents.AI.Hosting" Version="1.10.0-preview.260610.1" />
<PackageReference Include="Microsoft.Agents.AI.OpenAI" Version="1.10.0" />
<PackageReference Include="Microsoft.Agents.AI.Workflows" Version="1.10.0" />
<PackageReference Include="Microsoft.Agents.AI.Workflows.Generators" Version="1.10.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="10.0.9" /> <PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="10.0.9" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="10.0.9">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.9" /> <PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.9" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Tools" Version="10.0.9"> <PackageReference Include="Microsoft.EntityFrameworkCore.Tools" Version="10.0.9">
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
@@ -1,7 +1,7 @@
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Services;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
using SqlDatabaseVectorSearch.TextChunkers.Implementations;
namespace SqlDatabaseVectorSearch.TextChunkers; namespace SqlDatabaseVectorSearch.TextChunkers;
@@ -11,8 +11,8 @@ public class DefaultTextChunker(TokenizerService tokenizerService, IOptions<AppS
public IList<string> Split(string text) public IList<string> Split(string text)
{ {
var lines = TextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens); var lines = PlainTextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens); var paragraphs = PlainTextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
return paragraphs; return paragraphs;
} }
@@ -0,0 +1,372 @@
using System.Diagnostics;
using System.Text;
namespace SqlDatabaseVectorSearch.TextChunkers.Implementations;
/// <summary>
/// Split text in chunks, attempting to leave meaning intact.
/// For plain text, split looking at new lines first, then periods, and so on.
/// For markdown, split looking at punctuation first, and so on.
/// </summary>
internal static class PlainTextChunker
{
    /// <summary>
    /// Represents a list of strings with token count.
    /// Used to reduce the number of calls to the tokenizer.
    /// </summary>
    private sealed class StringListWithTokenCount(TokenCounter? tokenCounter)
    {
        private readonly TokenCounter? tokenCounter = tokenCounter;

        /// <summary>Adds a value, computing its token count with the configured counter (or the default estimate).</summary>
        public void Add(string value) => Values.Add((value, tokenCounter is null ? GetDefaultTokenCount(value.Length) : tokenCounter(value)));

        /// <summary>Adds a value whose token count is already known.</summary>
        public void Add(string value, int tokenCount) => Values.Add((value, tokenCount));

        /// <summary>Appends all the entries (values and token counts) of another list.</summary>
        public void AddRange(StringListWithTokenCount range) => Values.AddRange(range.Values);

        /// <summary>Removes <paramref name="count"/> entries starting at <paramref name="index"/>.</summary>
        public void RemoveRange(int index, int count) => Values.RemoveRange(index, count);

        /// <summary>Number of entries in the list.</summary>
        public int Count => Values.Count;

        /// <summary>Returns the stored strings, without their token counts.</summary>
        public List<string> ToStringList() => Values.Select(v => v.Value).ToList();

        private List<(string Value, int TokenCount)> Values { get; } = [];

        /// <summary>Returns the string stored at position <paramref name="i"/>.</summary>
        public string ValueAt(int i) => Values[i].Value;

        /// <summary>Returns the token count stored at position <paramref name="i"/>.</summary>
        public int TokenCountAt(int i) => Values[i].TokenCount;
    }

    /// <summary>
    /// Delegate for counting tokens in a string.
    /// </summary>
    /// <param name="input">The input string to count tokens in.</param>
    /// <returns>The number of tokens in the input string.</returns>
    public delegate int TokenCounter(string input);

    // Separator sets are tried in order; a null entry means "no separator", i.e. cut at the midpoint.
    private static readonly string?[] plainTextSplitOptions = ["\n", ".。.", "?!", ";", ":", ",,、", ")]}", " ", "-", null];
    private static readonly string?[] markdownSplitOptions = [".\u3002\uFF0E", "?!", ";", ":", ",\uFF0C\u3001", ")]}", " ", "-", "\n\r", null];

    /// <summary>
    /// Split plain text into lines.
    /// </summary>
    /// <param name="text">Text to split</param>
    /// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
    /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
    /// <returns>List of lines.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxTokensPerLine"/> is not positive.</exception>
    public static List<string> SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
    {
        ArgumentNullException.ThrowIfNull(text);
        ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));

        return InternalSplitLines(text, maxTokensPerLine, trim: true, plainTextSplitOptions, tokenCounter);
    }

    /// <summary>
    /// Split markdown text into lines.
    /// </summary>
    /// <param name="text">Text to split</param>
    /// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
    /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
    /// <returns>List of lines.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxTokensPerLine"/> is not positive.</exception>
    public static List<string> SplitMarkdownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
    {
        ArgumentNullException.ThrowIfNull(text);
        ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));

        return InternalSplitLines(text, maxTokensPerLine, trim: true, markdownSplitOptions, tokenCounter);
    }

    /// <summary>
    /// Split plain text into paragraphs.
    /// </summary>
    /// <param name="lines">Lines of text.</param>
    /// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
    /// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
    /// <param name="chunkHeader">Text to be prepended to each individual chunk.</param>
    /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
    /// <returns>List of paragraphs.</returns>
    public static List<string> SplitPlainTextParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null)
        => InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader,
            static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, plainTextSplitOptions, tokenCounter), tokenCounter);

    /// <summary>
    /// Split markdown text into paragraphs.
    /// </summary>
    /// <param name="lines">Lines of text.</param>
    /// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
    /// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
    /// <param name="chunkHeader">Text to be prepended to each individual chunk.</param>
    /// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
    /// <returns>List of paragraphs.</returns>
    public static List<string> SplitMarkdownParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null)
        => InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader,
            static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, markdownSplitOptions, tokenCounter), tokenCounter);

    /// <summary>
    /// Shared implementation of the paragraph splitters: validates the arguments, splits over-long lines,
    /// groups the resulting lines into paragraphs and finally applies header and overlap.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="lines"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxTokensPerParagraph"/> is not positive or <paramref name="overlapTokens"/> is negative.</exception>
    /// <exception cref="ArgumentException">The overlap and/or the header leave no room for paragraph content.</exception>
    private static List<string> InternalSplitTextParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func<string, int, TokenCounter?, List<string>> longLinesSplitter, TokenCounter? tokenCounter)
    {
        ArgumentNullException.ThrowIfNull(lines);
        ValidateMaxTokens(maxTokensPerParagraph, nameof(maxTokensPerParagraph));

        if (overlapTokens < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(overlapTokens), "overlapTokens cannot be negative.");
        }

        if (maxTokensPerParagraph <= overlapTokens)
        {
            throw new ArgumentException("overlapTokens cannot be larger than or equal to maxTokensPerParagraph.", nameof(overlapTokens));
        }

        // Optimize empty inputs if we can efficiently determine they're empty.
        if (lines is ICollection<string> c && c.Count == 0)
        {
            return [];
        }

        // The header and the overlap both consume part of the paragraph budget.
        var chunkHeaderTokens = chunkHeader is { Length: > 0 } ? GetTokenCount(chunkHeader, tokenCounter) : 0;
        var adjustedMaxTokensPerParagraph = maxTokensPerParagraph - overlapTokens - chunkHeaderTokens;

        if (adjustedMaxTokensPerParagraph <= 0)
        {
            throw new ArgumentException("chunkHeader and overlapTokens must leave room for paragraph content.", nameof(chunkHeader));
        }

        // Split long lines first
        var truncatedLines = lines.SelectMany(line => longLinesSplitter(NormalizeLineEndings(line), adjustedMaxTokensPerParagraph, tokenCounter));

        var paragraphs = BuildParagraph(truncatedLines, adjustedMaxTokensPerParagraph, tokenCounter);
        var processedParagraphs = ProcessParagraphs(paragraphs, adjustedMaxTokensPerParagraph, overlapTokens, chunkHeader, longLinesSplitter, tokenCounter);

        return processedParagraphs;
    }

    /// <summary>
    /// Accumulates lines into paragraphs, starting a new paragraph when adding the next line
    /// would reach <paramref name="maxTokensPerParagraph"/>. Each paragraph is trimmed before being returned.
    /// </summary>
    private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, int maxTokensPerParagraph, TokenCounter? tokenCounter)
    {
        StringBuilder paragraphBuilder = new();
        List<string> paragraphs = [];

        foreach (var line in truncatedLines)
        {
            if (paragraphBuilder.Length > 0)
            {
                string? paragraph = null;

                var currentCount = GetTokenCount(line, tokenCounter) + 1;
                if (currentCount < maxTokensPerParagraph)
                {
                    // Only count the accumulated paragraph when the new line alone does not already force a flush.
                    currentCount += tokenCounter is null ?
                        GetDefaultTokenCount(paragraphBuilder.Length) :
                        tokenCounter(paragraph = paragraphBuilder.ToString());
                }

                if (currentCount >= maxTokensPerParagraph)
                {
                    // Complete the paragraph and prepare for the next
                    paragraph ??= paragraphBuilder.ToString();
                    paragraphs.Add(paragraph.Trim());
                    paragraphBuilder.Clear();
                }
            }

            // Use an explicit '\n' instead of AppendLine: AppendLine writes Environment.NewLine, which would
            // re-introduce "\r\n" on Windows after the input has been normalized to '\n', making both the
            // chunk content and the default (length-based) token estimate platform-dependent.
            paragraphBuilder.Append(line).Append('\n');
        }

        if (paragraphBuilder.Length > 0)
        {
            // Add the final paragraph if there's anything remaining
            paragraphs.Add(paragraphBuilder.ToString().Trim());
        }

        return paragraphs;
    }

    /// <summary>
    /// Post-processes the paragraphs: merges a very short last paragraph into the previous one when the result
    /// still fits the budget, then prepends <paramref name="chunkHeader"/> and, when <paramref name="overlapTokens"/>
    /// is positive, appends the beginning of the next paragraph to every paragraph except the last.
    /// </summary>
    private static List<string> ProcessParagraphs(List<string> paragraphs, int adjustedMaxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func<string, int, TokenCounter?, List<string>> longLinesSplitter, TokenCounter? tokenCounter)
    {
        // distribute text more evenly in the last paragraphs when the last paragraph is too short.
        if (paragraphs.Count > 1)
        {
            var lastParagraph = paragraphs[^1];
            var secondLastParagraph = paragraphs[^2];

            if (GetTokenCount(lastParagraph, tokenCounter) < adjustedMaxTokensPerParagraph / 4)
            {
                var mergedParagraph = $"{secondLastParagraph} {lastParagraph}";
                if (GetTokenCount(mergedParagraph, tokenCounter) <= adjustedMaxTokensPerParagraph)
                {
                    paragraphs[^2] = mergedParagraph;
                    paragraphs.RemoveAt(paragraphs.Count - 1);
                }
            }
        }

        var processedParagraphs = new List<string>();
        var paragraphStringBuilder = new StringBuilder();

        for (var i = 0; i < paragraphs.Count; i++)
        {
            paragraphStringBuilder.Clear();

            if (chunkHeader is not null)
            {
                paragraphStringBuilder.Append(chunkHeader);
            }

            var paragraph = paragraphs[i];

            if (overlapTokens > 0 && i < paragraphs.Count - 1)
            {
                // The overlap is the first piece obtained by splitting the next paragraph to overlapTokens.
                var nextParagraph = paragraphs[i + 1];
                var split = longLinesSplitter(nextParagraph, overlapTokens, tokenCounter);

                paragraphStringBuilder.Append(paragraph);

                if (split.Count != 0)
                {
                    paragraphStringBuilder.Append(' ').Append(split[0]);
                }
            }
            else
            {
                paragraphStringBuilder.Append(paragraph);
            }

            processedParagraphs.Add(paragraphStringBuilder.ToString());
        }

        return processedParagraphs;
    }

    /// <summary>
    /// Splits <paramref name="text"/> into pieces, trying each entry of <paramref name="splitOptions"/> in order
    /// and stopping as soon as a pass reports that no piece is left over budget.
    /// </summary>
    private static List<string> InternalSplitLines(string text, int maxTokensPerLine, bool trim, string?[] splitOptions, TokenCounter? tokenCounter)
    {
        var result = new StringListWithTokenCount(tokenCounter);

        text = NormalizeLineEndings(text);
        result.Add(text);

        for (var i = 0; i < splitOptions.Length; i++)
        {
            var count = result.Count; // track where the original input left off
            var (splits2, inputWasSplit2) = Split(result, maxTokensPerLine, splitOptions[i].AsSpan(), trim, tokenCounter);

            result.AddRange(splits2);
            result.RemoveRange(0, count); // remove the original input

            if (!inputWasSplit2)
            {
                // Nothing is left that needs a finer-grained separator.
                break;
            }
        }

        return result.ToStringList();
    }

    /// <summary>
    /// Applies the single-string split to every entry of <paramref name="input"/> and concatenates the results.
    /// The returned flag is true when at least one entry reported that it still needs further splitting.
    /// </summary>
    private static (StringListWithTokenCount, bool) Split(StringListWithTokenCount input, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter? tokenCounter)
    {
        var inputWasSplit = false;
        StringListWithTokenCount result = new(tokenCounter);

        var count = input.Count;

        for (var i = 0; i < count; i++)
        {
            var (splits, split) = Split(input.ValueAt(i).AsSpan(), input.ValueAt(i), maxTokens, separators, trim, tokenCounter, input.TokenCountAt(i));
            result.AddRange(splits);
            inputWasSplit |= split;
        }

        return (result, inputWasSplit);
    }

    /// <summary>
    /// Recursively splits <paramref name="input"/> at the separator closest to its midpoint until every piece
    /// fits in <paramref name="maxTokens"/> or no cut point can be found. With empty <paramref name="separators"/>
    /// the cut is made at the midpoint.
    /// </summary>
    /// <returns>
    /// The resulting pieces and a flag that is true when some piece exceeded <paramref name="maxTokens"/>
    /// but no cut point was found for it, i.e. the caller should retry with another separator set.
    /// </returns>
    private static (StringListWithTokenCount, bool) Split(ReadOnlySpan<char> input, string? inputString, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter? tokenCounter, int inputTokenCount)
    {
        Debug.Assert(inputString is null || input.SequenceEqual(inputString.AsSpan()));

        StringListWithTokenCount result = new(tokenCounter);
        var inputWasSplit = false;

        if (inputTokenCount > maxTokens)
        {
            inputWasSplit = true;

            var half = input.Length / 2;
            var cutPoint = -1;

            if (separators.IsEmpty)
            {
                cutPoint = half;
            }
            else if (input.Length > 2)
            {
                var pos = 0;
                while (true)
                {
                    // The last character is excluded from the search, so a cut never leaves an empty second half.
                    var index = input[pos..^1].IndexOfAny(separators);
                    if (index < 0)
                    {
                        break;
                    }

                    index += pos;

                    if (Math.Abs(half - index) < Math.Abs(half - cutPoint))
                    {
                        // Cut right after the separator, so that it stays with the first half.
                        cutPoint = index + 1;
                    }

                    pos = index + 1;
                }
            }

            if (cutPoint > 0)
            {
                var firstHalf = input[..cutPoint];
                var secondHalf = input[cutPoint..];

                if (trim)
                {
                    firstHalf = firstHalf.Trim();
                    secondHalf = secondHalf.Trim();
                }

                // Recursion
                var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(firstHalf, tokenCounter));
                result.AddRange(splits1);

                var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(secondHalf, tokenCounter));
                result.AddRange(splits2);

                inputWasSplit = split1 || split2;
                return (result, inputWasSplit);
            }
        }

        var resultString = inputString ?? input.ToString();
        var resultTokenCount = inputTokenCount;

        if (trim)
        {
            var trimmedResult = resultString.Trim();
            if (!trimmedResult.Equals(resultString, StringComparison.Ordinal))
            {
                // Trimming changed the text, so the cached token count is no longer valid.
                resultString = trimmedResult;
                resultTokenCount = GetTokenCount(resultString, tokenCounter);
            }
        }

        result.Add(resultString, resultTokenCount);

        return (result, inputWasSplit);
    }

    /// <summary>Counts the tokens of <paramref name="input"/> with <paramref name="tokenCounter"/>, or with the default estimate when it is null.</summary>
    private static int GetTokenCount(string input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input);

    /// <summary>Span overload; a string is materialized only when a custom <paramref name="tokenCounter"/> is supplied.</summary>
    private static int GetTokenCount(ReadOnlySpan<char> input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input.ToString());

    /// <summary>Converts "\r\n" and lone '\r' to '\n'.</summary>
    private static string NormalizeLineEndings(string text) => text.Replace("\r\n", "\n").Replace('\r', '\n');

    /// <summary>Throws <see cref="ArgumentOutOfRangeException"/> when <paramref name="maxTokens"/> is zero or negative.</summary>
    private static void ValidateMaxTokens(int maxTokens, string parameterName)
    {
        if (maxTokens <= 0)
        {
            throw new ArgumentOutOfRangeException(parameterName, "The maximum token count must be a positive number.");
        }
    }

    /// <summary>
    /// Default token estimate used when no counter is supplied: one token every four characters,
    /// with a minimum of 1 for non-empty input and 0 for empty input.
    /// </summary>
    private static int GetDefaultTokenCount(int length)
    {
        Debug.Assert(length >= 0);
        return length == 0 ? 0 : Math.Max(1, length >> 2);
    }
}
@@ -1,7 +1,7 @@
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.Services; using SqlDatabaseVectorSearch.Services;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
using SqlDatabaseVectorSearch.TextChunkers.Implementations;
namespace SqlDatabaseVectorSearch.TextChunkers; namespace SqlDatabaseVectorSearch.TextChunkers;
@@ -11,8 +11,8 @@ public class MarkdownTextChunker(TokenizerService tokenizerService, IOptions<App
/// <summary>
/// Splits Markdown text in two passes: first into lines limited by appSettings.MaxTokensPerLine, then into
/// paragraphs limited by appSettings.MaxTokensPerParagraph, passing appSettings.OverlapTokens to the paragraph pass.
/// Both passes use tokenizerService.CountEmbeddingTokens as the token counter.
/// </summary>
/// <param name="text">The Markdown text to split.</param>
/// <returns>The paragraphs produced by the paragraph pass.</returns>
public IList<string> Split(string text) public IList<string> Split(string text)
{ {
var lines = TextChunker.SplitMarkDownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens); var lines = PlainTextChunker.SplitMarkdownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
var paragraphs = TextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens); var paragraphs = PlainTextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
return paragraphs; return paragraphs;
} }
@@ -0,0 +1,18 @@
using Microsoft.Agents.AI.Workflows;
namespace SqlDatabaseVectorSearch.Workflows;
/// <summary>
/// Workflow executor that turns an uploaded form file into an <see cref="EmbeddingRequest"/>.
/// </summary>
public partial class FormFileToEmbeddingRequestExecutor() : Executor(nameof(FormFileToEmbeddingRequestExecutor))
{
    /// <summary>
    /// Opens the uploaded file for reading and wraps it, together with its file name, a MIME type
    /// derived from the file name, and the optional document id, in an <see cref="EmbeddingRequest"/>.
    /// </summary>
    /// <param name="request">The uploaded file and the optional document id.</param>
    /// <param name="context">The workflow context (not used by this handler).</param>
    /// <param name="cancellationToken">The cancellation token (not used; the handler completes synchronously).</param>
    /// <returns>An already-completed task holding the new request.</returns>
    [MessageHandler]
    private ValueTask<EmbeddingRequest> HandleAsync(FormFileEmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken)
    {
        var file = request.File;

        var content = file.OpenReadStream();
        var fileName = Path.GetFileName(file.FileName);

        // Note: file.ContentType is not 100% reliable (for example, for markdown file),
        // so the MIME type is looked up from the file name instead.
        var contentType = MimeMapping.MimeUtility.GetMimeMapping(file.FileName);

        var embeddingRequest = new EmbeddingRequest(content, fileName, contentType, request.DocumentId);
        return new ValueTask<EmbeddingRequest>(embeddingRequest);
    }
}
/// <summary>
/// Input message handled by <see cref="FormFileToEmbeddingRequestExecutor"/>.
/// </summary>
/// <param name="File">The uploaded file to import.</param>
/// <param name="DocumentId">Optional document id; it is copied unchanged into the resulting <see cref="EmbeddingRequest"/>.</param>
public record class FormFileEmbeddingRequest(IFormFile File, Guid? DocumentId);
/// <summary>
/// Describes a document whose content must be decoded and embedded.
/// </summary>
/// <param name="Content">The stream from which the document content is read.</param>
/// <param name="FileName">The document file name (without directory information when produced by <see cref="FormFileToEmbeddingRequestExecutor"/>).</param>
/// <param name="ContentType">The MIME type of the content.</param>
/// <param name="DocumentId">Optional document id associated with the import.</param>
public record class EmbeddingRequest(Stream Content, string FileName, string ContentType, Guid? DocumentId);
@@ -0,0 +1,41 @@
using Microsoft.Agents.AI.Workflows;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;
using SqlDatabaseVectorSearch.ContentDecoders;
using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Services;
using SqlDatabaseVectorSearch.Settings;
namespace SqlDatabaseVectorSearch.Workflows;
/// <summary>
/// Workflow executor that decodes the content of an <see cref="EmbeddingRequest"/> into chunks and
/// generates one embedding per chunk, in batches of <c>AppSettings.EmbeddingBatchSize</c>.
/// </summary>
// NOTE(review): the logger category is VectorSearchService rather than this executor — confirm this is intentional.
public partial class GenerateEmbeddingExecutor(IServiceProvider serviceProvider, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger) : Executor(nameof(GenerateEmbeddingExecutor))
{
    private readonly AppSettings appSettings = appSettingsOptions.Value;

    /// <summary>
    /// Decodes the request content, counts the tokens of the whole document and generates the embeddings.
    /// </summary>
    /// <param name="request">The document to process.</param>
    /// <param name="context">The workflow context (not used by this handler).</param>
    /// <param name="cancellationToken">Token used to cancel decoding and embedding generation.</param>
    /// <returns>The original request together with its chunks, their embeddings (in chunk order) and the total token count.</returns>
    /// <exception cref="NotSupportedException">No content decoder is registered for the request content type.</exception>
    [MessageHandler]
    private async ValueTask<EmbeddingResponse> HandleAsync(EmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken)
    {
        // Extract the contents of the file. Decoders are registered as keyed services, keyed by content type.
        var decoder = serviceProvider.GetKeyedService<IContentDecoder>(request.ContentType) ?? throw new NotSupportedException($"Content type '{request.ContentType}' is not supported.");

        // Materialize the decoded chunks exactly once: they are enumerated here and again by whoever
        // consumes the response, and a deferred sequence would otherwise be re-evaluated on every pass.
        var chunks = (await decoder.DecodeAsync(request.Content, request.ContentType, cancellationToken)).ToList();
        var chunkContents = chunks.Select(p => p.Content).ToList();

        // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
        var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", chunkContents));

        // Process paragraphs in batches. One embedding is expected per chunk, so the list is pre-sized.
        var embeddings = new List<Embedding<float>>(chunkContents.Count);
        foreach (var batch in chunkContents.Chunk(appSettings.EmbeddingBatchSize))
        {
            logger.LogDebug("Processing batch of {Count} chunks for embedding generation...", batch.Length);

            // Generate embeddings for this batch.
            var batchEmbeddings = await embeddingGenerator.GenerateAsync(batch, cancellationToken: cancellationToken);
            embeddings.AddRange(batchEmbeddings);
        }

        return new EmbeddingResponse(request, chunks, embeddings, tokenCount);
    }
}
/// <summary>
/// Result produced by <see cref="GenerateEmbeddingExecutor"/> and consumed by <see cref="StoreEmbeddingExecutor"/>.
/// </summary>
/// <param name="Request">The request the embeddings were generated for.</param>
/// <param name="Chunks">The chunks decoded from the request content.</param>
/// <param name="Embeddings">The generated embeddings; the consumer pairs the embedding at position i with the chunk at position i.</param>
/// <param name="TokenCount">The token count computed over all chunk contents joined by a single space.</param>
public record class EmbeddingResponse(EmbeddingRequest Request, IEnumerable<Chunk> Chunks, IEnumerable<Embedding<float>> Embeddings, int TokenCount);
@@ -0,0 +1,58 @@
using Microsoft.Agents.AI.Workflows;
using Microsoft.Data.SqlTypes;
using Microsoft.EntityFrameworkCore;
using SqlDatabaseVectorSearch.Data;
using SqlDatabaseVectorSearch.Services;
using Entities = SqlDatabaseVectorSearch.Data.Entities;
namespace SqlDatabaseVectorSearch.Workflows;
/// <summary>
/// Workflow executor that stores a document, its chunks and their embeddings in the database
/// inside a single transaction run through the context's execution strategy.
/// </summary>
public partial class StoreEmbeddingExecutor(ApplicationDbContext dbContext, DocumentService documentService, TokenizerService tokenizerService, TimeProvider timeProvider, ILogger<VectorSearchService> logger) : Executor(nameof(StoreEmbeddingExecutor))
{
    /// <summary>
    /// Replaces (when a document id is supplied) or creates the document and saves one
    /// <c>DocumentChunk</c> per embedding.
    /// </summary>
    /// <param name="embeddingData">The chunks and embeddings to store, plus the originating request.</param>
    /// <param name="context">The workflow context (not used by this handler).</param>
    /// <param name="cancellationToken">Token used to cancel the database operations.</param>
    /// <returns>The id of the stored document and the token count carried by <paramref name="embeddingData"/>.</returns>
    [MessageHandler]
    private async ValueTask<StoreEmbeddingResponse> HandleAsync(EmbeddingResponse embeddingData, IWorkflowContext context, CancellationToken cancellationToken)
    {
        // Materialize the chunks once so the per-embedding lookup below is a direct index access
        // instead of re-enumerating the sequence for every embedding.
        var chunks = embeddingData.Chunks.ToList();

        var strategy = dbContext.Database.CreateExecutionStrategy();

        // NOTE(review): if the execution strategy retries this delegate, entities added to dbContext by a
        // failed attempt are still tracked — confirm whether the change tracker must be reset between attempts.
        var storedDocument = await strategy.ExecuteAsync(async ct =>
        {
            // Disposing the transaction rolls it back when the commit below is never reached, so a failed
            // attempt does not leave an open transaction on the context.
            await using var transaction = await dbContext.Database.BeginTransactionAsync(ct);

            if (embeddingData.Request.DocumentId.HasValue)
            {
                // If the user is importing a document that already exists, delete the previous one.
                await documentService.DeleteAsync(embeddingData.Request.DocumentId.Value, ct);
            }

            var document = new Entities.Document { Id = embeddingData.Request.DocumentId.GetValueOrDefault(), Name = embeddingData.Request.FileName, CreationDate = timeProvider.GetUtcNow() };
            dbContext.Documents.Add(document);

            // Save the document chunks and the corresponding embedding in the database.
            foreach (var (index, embedding) in embeddingData.Embeddings.Index())
            {
                var chunk = chunks[index];

                // Counting tokens is only needed for the log message, so skip it when debug logging is off.
                if (logger.IsEnabled(LogLevel.Debug))
                {
                    logger.LogDebug("Storing a chunk of {TokenCount} tokens.", tokenizerService.CountEmbeddingTokens(chunk.Content));
                }

                var documentChunk = new Entities.DocumentChunk
                {
                    Document = document,
                    Index = index,
                    PageNumber = chunk.PageNumber,
                    IndexOnPage = chunk.IndexOnPage,
                    Content = chunk.Content,
                    Embedding = new SqlVector<float>(embedding.Vector)
                };

                dbContext.DocumentChunks.Add(documentChunk);
            }

            await dbContext.SaveChangesAsync(ct);
            await transaction.CommitAsync(ct);

            return document;
        }, cancellationToken);

        return new(storedDocument.Id, embeddingData.TokenCount);
    }
}
/// <summary>
/// Result returned by <see cref="StoreEmbeddingExecutor"/> after a document has been stored.
/// </summary>
/// <param name="DocumentId">The id of the stored document.</param>
/// <param name="TokenCount">The token count taken from the <see cref="EmbeddingResponse"/> that was stored.</param>
public record class StoreEmbeddingResponse(Guid DocumentId, int TokenCount);