mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a761ddcd7 | |||
| 14f983307e | |||
| 5624f73640 |
@@ -177,7 +177,7 @@ else
|
||||
var vectorSearchService = scope.ServiceProvider.GetRequiredService<VectorSearchService>();
|
||||
|
||||
var documentId = string.IsNullOrWhiteSpace(Model.DocumentId) ? null : (Guid?)Guid.Parse(Model.DocumentId);
|
||||
await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId);
|
||||
//await vectorSearchService.ImportAsync(stream, fileName, MimeUtility.GetMimeMapping(fileName), documentId);
|
||||
|
||||
ToastService.Notify(await CreateToastMessageAsync(ToastType.Success, "Upload document", $"The document {fileName} has been successfully uploaded and indexed."));
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
|
||||
using System.ComponentModel;
|
||||
using Microsoft.AspNetCore.Http.HttpResults;
|
||||
using MimeMapping;
|
||||
using SqlDatabaseVectorSearch.Models;
|
||||
using SqlDatabaseVectorSearch.Services;
|
||||
using SqlDatabaseVectorSearch.Workflows;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Endpoints;
|
||||
|
||||
@@ -23,12 +23,8 @@ public class DocumentEndpoints : IEndpointRouteHandlerBuilder
|
||||
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
||||
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
|
||||
{
|
||||
using var stream = file.OpenReadStream();
|
||||
|
||||
// Note: file.ContentType is not 100% reliable (for example, for markdown file).
|
||||
var response = await vectorSearchService.ImportAsync(stream, file.FileName, MimeUtility.GetMimeMapping(file.FileName), documentId, cancellationToken);
|
||||
|
||||
return TypedResults.Ok(response);
|
||||
var result = await vectorSearchService.ImportAsync(new FormFileEmbeddingRequest(file, documentId), cancellationToken);
|
||||
return TypedResults.Ok(result);
|
||||
})
|
||||
.DisableAntiforgery()
|
||||
.ProducesProblem(StatusCodes.Status400BadRequest)
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
namespace SqlDatabaseVectorSearch.Models;
|
||||
|
||||
public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount);
|
||||
@@ -1,8 +1,14 @@
|
||||
using System.ClientModel;
|
||||
using System.Net.Mime;
|
||||
using System.Text.Json.Serialization;
|
||||
using FluentValidation;
|
||||
using Microsoft.Agents.AI.Hosting;
|
||||
using Microsoft.Agents.AI.Workflows;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.SemanticKernel;
|
||||
using OpenAI;
|
||||
using OpenAI.Responses;
|
||||
using SqlDatabaseVectorSearch.Components;
|
||||
using SqlDatabaseVectorSearch.ContentDecoders;
|
||||
using SqlDatabaseVectorSearch.Data;
|
||||
@@ -10,6 +16,7 @@ using SqlDatabaseVectorSearch.Extensions;
|
||||
using SqlDatabaseVectorSearch.Services;
|
||||
using SqlDatabaseVectorSearch.Settings;
|
||||
using SqlDatabaseVectorSearch.TextChunkers;
|
||||
using SqlDatabaseVectorSearch.Workflows;
|
||||
using TinyHelpers.AspNetCore.Extensions;
|
||||
using TinyHelpers.AspNetCore.OpenApi;
|
||||
|
||||
@@ -54,10 +61,29 @@ builder.Services.ConfigureHttpClientDefaults(configure =>
|
||||
});
|
||||
});
|
||||
|
||||
builder.Services.AddSingleton(_ =>
|
||||
{
|
||||
var embeddingClient = new OpenAIClient(new ApiKeyCredential(aiSettings.Embedding.ApiKey), new()
|
||||
{
|
||||
Endpoint = new(aiSettings.Embedding.Endpoint),
|
||||
}).GetEmbeddingClient(aiSettings.Embedding.Deployment).AsIEmbeddingGenerator(aiSettings.Embedding.Dimensions);
|
||||
|
||||
return embeddingClient;
|
||||
});
|
||||
|
||||
builder.Services.AddChatClient(_ =>
|
||||
{
|
||||
var chatClient = new OpenAIClient(new ApiKeyCredential(aiSettings.ChatCompletion.ApiKey), new()
|
||||
{
|
||||
Endpoint = new(aiSettings.ChatCompletion.Endpoint),
|
||||
}).GetResponsesClient().AsIChatClientWithStoredOutputDisabled(aiSettings.ChatCompletion.Deployment);
|
||||
|
||||
return chatClient;
|
||||
});
|
||||
|
||||
// Semantic Kernel is used to generate embeddings and to reformulate questions taking into account all the previous interactions,
|
||||
// so that embeddings themselves can be generated more accurately.
|
||||
builder.Services.AddKernel()
|
||||
.AddAzureOpenAIEmbeddingGenerator(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, modelId: aiSettings.Embedding.ModelId, dimensions: aiSettings.Embedding.Dimensions)
|
||||
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey, modelId: aiSettings.ChatCompletion.ModelId);
|
||||
|
||||
builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaTypeNames.Application.Pdf);
|
||||
@@ -74,6 +100,25 @@ builder.Services.AddSingleton<ChatService>();
|
||||
builder.Services.AddScoped<DocumentService>();
|
||||
builder.Services.AddScoped<VectorSearchService>();
|
||||
|
||||
builder.Services.AddSingleton<FormFileToEmbeddingRequestExecutor>();
|
||||
builder.Services.AddSingleton<GenerateEmbeddingExecutor>();
|
||||
builder.Services.AddScoped<StoreEmbeddingExecutor>(); // This executor is registered as scoped because it uses the DbContext, which is also scoped.
|
||||
|
||||
builder.AddWorkflow("EmbeddingWorkflow", (services, key) =>
|
||||
{
|
||||
var formfileToConversionRequestExecutor = services.GetRequiredService<FormFileToEmbeddingRequestExecutor>();
|
||||
var generateEmbeddingExecutor = services.GetRequiredService<GenerateEmbeddingExecutor>();
|
||||
var storeEmbeddingExecutor = services.GetRequiredService<StoreEmbeddingExecutor>();
|
||||
|
||||
var workflow = new WorkflowBuilder(formfileToConversionRequestExecutor).WithName(key)
|
||||
.AddEdge(formfileToConversionRequestExecutor, generateEmbeddingExecutor)
|
||||
.AddEdge(generateEmbeddingExecutor, storeEmbeddingExecutor)
|
||||
.WithOutputFrom(storeEmbeddingExecutor)
|
||||
.Build(validateOrphans: true);
|
||||
|
||||
return workflow;
|
||||
}, ServiceLifetime.Scoped);
|
||||
|
||||
builder.Services.AddOpenApi(options =>
|
||||
{
|
||||
options.RemoveServerList();
|
||||
|
||||
@@ -2,84 +2,37 @@
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Text;
|
||||
using System.Text.RegularExpressions;
|
||||
using Microsoft.Agents.AI.Workflows;
|
||||
using Microsoft.Data.SqlTypes;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Extensions.Options;
|
||||
using SqlDatabaseVectorSearch.ContentDecoders;
|
||||
using SqlDatabaseVectorSearch.Data;
|
||||
using SqlDatabaseVectorSearch.Models;
|
||||
using SqlDatabaseVectorSearch.Settings;
|
||||
using SqlDatabaseVectorSearch.Workflows;
|
||||
using ChatResponse = SqlDatabaseVectorSearch.Models.ChatResponse;
|
||||
using Entities = SqlDatabaseVectorSearch.Data.Entities;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Services;
|
||||
|
||||
public partial class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, DocumentService documentService, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
|
||||
public partial class VectorSearchService([FromKeyedServices("EmbeddingWorkflow")] Workflow workflow, ApplicationDbContext dbContext, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
|
||||
{
|
||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||
|
||||
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId, CancellationToken cancellationToken = default)
|
||||
public async Task<StoreEmbeddingResponse> ImportAsync(FormFileEmbeddingRequest request, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Extract the contents of the file.
|
||||
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
|
||||
var chunks = await decoder.DecodeAsync(stream, contentType, cancellationToken);
|
||||
var chunkContents = chunks.Select(p => p.Content).ToList();
|
||||
await using var run = await InProcessExecution.RunAsync(workflow, request, cancellationToken: cancellationToken);
|
||||
var events = run.NewEvents.ToList();
|
||||
|
||||
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
||||
var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", chunkContents));
|
||||
|
||||
var strategy = dbContext.Database.CreateExecutionStrategy();
|
||||
var document = await strategy.ExecuteAsync(async (cancellationToken) =>
|
||||
var exception = events.OfType<WorkflowErrorEvent>().Select(e => e.Exception).FirstOrDefault();
|
||||
if (exception is not null)
|
||||
{
|
||||
await dbContext.Database.BeginTransactionAsync(cancellationToken);
|
||||
throw exception;
|
||||
}
|
||||
|
||||
if (documentId.HasValue)
|
||||
{
|
||||
// If the user is importing a document that already exists, delete the previous one.
|
||||
await documentService.DeleteAsync(documentId.Value, cancellationToken);
|
||||
}
|
||||
|
||||
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
|
||||
dbContext.Documents.Add(document);
|
||||
|
||||
// Process paragraphs in batches.
|
||||
var embeddings = new List<Embedding<float>>();
|
||||
foreach (var batch in chunkContents.Chunk(appSettings.EmbeddingBatchSize))
|
||||
{
|
||||
logger.LogDebug("Processing batch of {Count} chunks for embedding generation...", batch.Length);
|
||||
|
||||
// Generate embeddings for this batch.
|
||||
var batchEmbeddings = await embeddingGenerator.GenerateAsync(batch, cancellationToken: cancellationToken);
|
||||
embeddings.AddRange(batchEmbeddings);
|
||||
}
|
||||
|
||||
// Save the document chunks and the corresponding embedding in the database.
|
||||
foreach (var (index, embedding) in embeddings.Index())
|
||||
{
|
||||
var chunk = chunks.ElementAt(index);
|
||||
logger.LogDebug("Storing a chunk of {TokenCount} tokens.", tokenizerService.CountEmbeddingTokens(chunk.Content));
|
||||
|
||||
var documentChunk = new Entities.DocumentChunk
|
||||
{
|
||||
Document = document,
|
||||
Index = index,
|
||||
PageNumber = chunk.PageNumber,
|
||||
IndexOnPage = chunk.IndexOnPage,
|
||||
Content = chunk.Content,
|
||||
Embedding = new SqlVector<float>(embedding.Vector)
|
||||
};
|
||||
|
||||
dbContext.DocumentChunks.Add(documentChunk);
|
||||
}
|
||||
|
||||
await dbContext.SaveChangesAsync(cancellationToken);
|
||||
await dbContext.Database.CommitTransactionAsync(cancellationToken);
|
||||
|
||||
return document;
|
||||
}, cancellationToken);
|
||||
|
||||
return new(document.Id, tokenCount);
|
||||
var result = events.OfType<WorkflowOutputEvent>().Select(e => e.Data).OfType<StoreEmbeddingResponse>().First();
|
||||
return result;
|
||||
}
|
||||
|
||||
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
<TargetFramework>net10.0</TargetFramework>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
<NoWarn>$(NoWarn);SKEXP0010;SKEXP0050</NoWarn>
|
||||
<NoWarn>$(NoWarn);SKEXP0010;SKEXP0050;OPENAI001;MAAI001</NoWarn>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
@@ -12,7 +12,18 @@
|
||||
<PackageReference Include="DocumentFormat.OpenXml" Version="3.5.1" />
|
||||
<PackageReference Include="EntityFrameworkCore.Exceptions.SqlServer" Version="10.0.1" />
|
||||
<PackageReference Include="FluentValidation.DependencyInjectionExtensions" Version="12.1.1" />
|
||||
<PackageReference Include="Microsoft.Agents.AI.Hosting" Version="1.10.0-preview.260610.1" />
|
||||
<PackageReference Include="Microsoft.Agents.AI.OpenAI" Version="1.10.0" />
|
||||
<PackageReference Include="Microsoft.Agents.AI.Workflows" Version="1.10.0" />
|
||||
<PackageReference Include="Microsoft.Agents.AI.Workflows.Generators" Version="1.10.0">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
</PackageReference>
|
||||
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="10.0.9" />
|
||||
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="10.0.9">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
|
||||
</PackageReference>
|
||||
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.9" />
|
||||
<PackageReference Include="Microsoft.EntityFrameworkCore.Tools" Version="10.0.9">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.SemanticKernel.Text;
|
||||
using SqlDatabaseVectorSearch.Services;
|
||||
using SqlDatabaseVectorSearch.Settings;
|
||||
using SqlDatabaseVectorSearch.TextChunkers.Implementations;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.TextChunkers;
|
||||
|
||||
@@ -11,8 +11,8 @@ public class DefaultTextChunker(TokenizerService tokenizerService, IOptions<AppS
|
||||
|
||||
public IList<string> Split(string text)
|
||||
{
|
||||
var lines = TextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
|
||||
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
|
||||
var lines = PlainTextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
|
||||
var paragraphs = PlainTextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
|
||||
|
||||
return paragraphs;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,372 @@
|
||||
using System.Diagnostics;
|
||||
using System.Text;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.TextChunkers.Implementations;
|
||||
|
||||
/// <summary>
|
||||
/// Split text in chunks, attempting to leave meaning intact.
|
||||
/// For plain text, split looking at new lines first, then periods, and so on.
|
||||
/// For markdown, split looking at punctuation first, and so on.
|
||||
/// </summary>
|
||||
internal static class PlainTextChunker
|
||||
{
|
||||
/// <summary>
|
||||
/// Represents a list of strings with token count.
|
||||
/// Used to reduce the number of calls to the tokenizer.
|
||||
/// </summary>
|
||||
private sealed class StringListWithTokenCount(TokenCounter? tokenCounter)
|
||||
{
|
||||
private readonly TokenCounter? tokenCounter = tokenCounter;
|
||||
|
||||
public void Add(string value) => Values.Add((value, tokenCounter is null ? GetDefaultTokenCount(value.Length) : tokenCounter(value)));
|
||||
|
||||
public void Add(string value, int tokenCount) => Values.Add((value, tokenCount));
|
||||
|
||||
public void AddRange(StringListWithTokenCount range) => Values.AddRange(range.Values);
|
||||
|
||||
public void RemoveRange(int index, int count) => Values.RemoveRange(index, count);
|
||||
|
||||
public int Count => Values.Count;
|
||||
|
||||
public List<string> ToStringList() => Values.Select(v => v.Value).ToList();
|
||||
|
||||
private List<(string Value, int TokenCount)> Values { get; } = [];
|
||||
|
||||
public string ValueAt(int i) => Values[i].Value;
|
||||
|
||||
public int TokenCountAt(int i) => Values[i].TokenCount;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Delegate for counting tokens in a string.
|
||||
/// </summary>
|
||||
/// <param name="input">The input string to count tokens in.</param>
|
||||
/// <returns>The number of tokens in the input string.</returns>
|
||||
public delegate int TokenCounter(string input);
|
||||
|
||||
private static readonly string?[] plainTextSplitOptions = ["\n", ".。.", "?!", ";", ":", ",,、", ")]}", " ", "-", null];
|
||||
private static readonly string?[] markdownSplitOptions = [".\u3002\uFF0E", "?!", ";", ":", ",\uFF0C\u3001", ")]}", " ", "-", "\n\r", null];
|
||||
|
||||
/// <summary>
|
||||
/// Split plain text into lines.
|
||||
/// </summary>
|
||||
/// <param name="text">Text to split</param>
|
||||
/// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
|
||||
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
|
||||
/// <returns>List of lines.</returns>
|
||||
public static List<string> SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(text);
|
||||
ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));
|
||||
|
||||
return InternalSplitLines(text, maxTokensPerLine, trim: true, plainTextSplitOptions, tokenCounter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Split markdown text into lines.
|
||||
/// </summary>
|
||||
/// <param name="text">Text to split</param>
|
||||
/// <param name="maxTokensPerLine">Maximum number of tokens per line.</param>
|
||||
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
|
||||
/// <returns>List of lines.</returns>
|
||||
public static List<string> SplitMarkdownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(text);
|
||||
ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));
|
||||
|
||||
return InternalSplitLines(text, maxTokensPerLine, trim: true, markdownSplitOptions, tokenCounter);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Split plain text into paragraphs.
|
||||
/// </summary>
|
||||
/// <param name="lines">Lines of text.</param>
|
||||
/// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
|
||||
/// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
|
||||
/// <param name="chunkHeader">Text to be prepended to each individual chunk.</param>
|
||||
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
|
||||
/// <returns>List of paragraphs.</returns>
|
||||
public static List<string> SplitPlainTextParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null)
|
||||
=> InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader,
|
||||
static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, plainTextSplitOptions, tokenCounter), tokenCounter);
|
||||
|
||||
/// <summary>
|
||||
/// Split markdown text into paragraphs.
|
||||
/// </summary>
|
||||
/// <param name="lines">Lines of text.</param>
|
||||
/// <param name="maxTokensPerParagraph">Maximum number of tokens per paragraph.</param>
|
||||
/// <param name="overlapTokens">Number of tokens to overlap between paragraphs.</param>
|
||||
/// <param name="chunkHeader">Text to be prepended to each individual chunk.</param>
|
||||
/// <param name="tokenCounter">Function to count tokens in a string. If not supplied, the default counter will be used.</param>
|
||||
/// <returns>List of paragraphs.</returns>
|
||||
public static List<string> SplitMarkdownParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null)
|
||||
=> InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader,
|
||||
static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, markdownSplitOptions, tokenCounter), tokenCounter);
|
||||
|
||||
private static List<string> InternalSplitTextParagraphs(IEnumerable<string> lines, int maxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func<string, int, TokenCounter?, List<string>> longLinesSplitter, TokenCounter? tokenCounter)
|
||||
{
|
||||
ArgumentNullException.ThrowIfNull(lines);
|
||||
ValidateMaxTokens(maxTokensPerParagraph, nameof(maxTokensPerParagraph));
|
||||
|
||||
if (overlapTokens < 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(nameof(overlapTokens), "overlapTokens cannot be negative.");
|
||||
}
|
||||
|
||||
if (maxTokensPerParagraph <= overlapTokens)
|
||||
{
|
||||
throw new ArgumentException("overlapTokens cannot be larger than or equal to maxTokensPerParagraph.", nameof(overlapTokens));
|
||||
}
|
||||
|
||||
// Optimize empty inputs if we can efficiently determine they're empty.
|
||||
if (lines is ICollection<string> c && c.Count == 0)
|
||||
{
|
||||
return [];
|
||||
}
|
||||
|
||||
var chunkHeaderTokens = chunkHeader is { Length: > 0 } ? GetTokenCount(chunkHeader, tokenCounter) : 0;
|
||||
var adjustedMaxTokensPerParagraph = maxTokensPerParagraph - overlapTokens - chunkHeaderTokens;
|
||||
if (adjustedMaxTokensPerParagraph <= 0)
|
||||
{
|
||||
throw new ArgumentException("chunkHeader and overlapTokens must leave room for paragraph content.", nameof(chunkHeader));
|
||||
}
|
||||
|
||||
// Split long lines first
|
||||
var truncatedLines = lines.SelectMany(line => longLinesSplitter(NormalizeLineEndings(line), adjustedMaxTokensPerParagraph, tokenCounter));
|
||||
|
||||
var paragraphs = BuildParagraph(truncatedLines, adjustedMaxTokensPerParagraph, tokenCounter);
|
||||
var processedParagraphs = ProcessParagraphs(paragraphs, adjustedMaxTokensPerParagraph, overlapTokens, chunkHeader, longLinesSplitter, tokenCounter);
|
||||
|
||||
return processedParagraphs;
|
||||
}
|
||||
|
||||
private static List<string> BuildParagraph(IEnumerable<string> truncatedLines, int maxTokensPerParagraph, TokenCounter? tokenCounter)
|
||||
{
|
||||
StringBuilder paragraphBuilder = new();
|
||||
List<string> paragraphs = [];
|
||||
|
||||
foreach (var line in truncatedLines)
|
||||
{
|
||||
if (paragraphBuilder.Length > 0)
|
||||
{
|
||||
string? paragraph = null;
|
||||
|
||||
var currentCount = GetTokenCount(line, tokenCounter) + 1;
|
||||
if (currentCount < maxTokensPerParagraph)
|
||||
{
|
||||
currentCount += tokenCounter is null ?
|
||||
GetDefaultTokenCount(paragraphBuilder.Length) :
|
||||
tokenCounter(paragraph = paragraphBuilder.ToString());
|
||||
}
|
||||
|
||||
if (currentCount >= maxTokensPerParagraph)
|
||||
{
|
||||
// Complete the paragraph and prepare for the next
|
||||
paragraph ??= paragraphBuilder.ToString();
|
||||
paragraphs.Add(paragraph.Trim());
|
||||
paragraphBuilder.Clear();
|
||||
}
|
||||
}
|
||||
|
||||
paragraphBuilder.AppendLine(line);
|
||||
}
|
||||
|
||||
if (paragraphBuilder.Length > 0)
|
||||
{
|
||||
// Add the final paragraph if there's anything remaining
|
||||
paragraphs.Add(paragraphBuilder.ToString().Trim());
|
||||
}
|
||||
|
||||
return paragraphs;
|
||||
}
|
||||
|
||||
private static List<string> ProcessParagraphs(List<string> paragraphs, int adjustedMaxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func<string, int, TokenCounter?, List<string>> longLinesSplitter, TokenCounter? tokenCounter)
|
||||
{
|
||||
// distribute text more evenly in the last paragraphs when the last paragraph is too short.
|
||||
if (paragraphs.Count > 1)
|
||||
{
|
||||
var lastParagraph = paragraphs[^1];
|
||||
var secondLastParagraph = paragraphs[^2];
|
||||
|
||||
if (GetTokenCount(lastParagraph, tokenCounter) < adjustedMaxTokensPerParagraph / 4)
|
||||
{
|
||||
var mergedParagraph = $"{secondLastParagraph} {lastParagraph}";
|
||||
|
||||
if (GetTokenCount(mergedParagraph, tokenCounter) <= adjustedMaxTokensPerParagraph)
|
||||
{
|
||||
paragraphs[^2] = mergedParagraph;
|
||||
paragraphs.RemoveAt(paragraphs.Count - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var processedParagraphs = new List<string>();
|
||||
var paragraphStringBuilder = new StringBuilder();
|
||||
|
||||
for (var i = 0; i < paragraphs.Count; i++)
|
||||
{
|
||||
paragraphStringBuilder.Clear();
|
||||
|
||||
if (chunkHeader is not null)
|
||||
{
|
||||
paragraphStringBuilder.Append(chunkHeader);
|
||||
}
|
||||
|
||||
var paragraph = paragraphs[i];
|
||||
|
||||
if (overlapTokens > 0 && i < paragraphs.Count - 1)
|
||||
{
|
||||
var nextParagraph = paragraphs[i + 1];
|
||||
var split = longLinesSplitter(nextParagraph, overlapTokens, tokenCounter);
|
||||
|
||||
paragraphStringBuilder.Append(paragraph);
|
||||
|
||||
if (split.Count != 0)
|
||||
{
|
||||
paragraphStringBuilder.Append(' ').Append(split[0]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
paragraphStringBuilder.Append(paragraph);
|
||||
}
|
||||
|
||||
processedParagraphs.Add(paragraphStringBuilder.ToString());
|
||||
}
|
||||
|
||||
return processedParagraphs;
|
||||
}
|
||||
|
||||
private static List<string> InternalSplitLines(string text, int maxTokensPerLine, bool trim, string?[] splitOptions, TokenCounter? tokenCounter)
|
||||
{
|
||||
var result = new StringListWithTokenCount(tokenCounter);
|
||||
|
||||
text = NormalizeLineEndings(text);
|
||||
result.Add(text);
|
||||
for (var i = 0; i < splitOptions.Length; i++)
|
||||
{
|
||||
var count = result.Count; // track where the original input left off
|
||||
var (splits2, inputWasSplit2) = Split(result, maxTokensPerLine, splitOptions[i].AsSpan(), trim, tokenCounter);
|
||||
result.AddRange(splits2);
|
||||
result.RemoveRange(0, count); // remove the original input
|
||||
if (!inputWasSplit2)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result.ToStringList();
|
||||
}
|
||||
|
||||
private static (StringListWithTokenCount, bool) Split(StringListWithTokenCount input, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter? tokenCounter)
|
||||
{
|
||||
var inputWasSplit = false;
|
||||
StringListWithTokenCount result = new(tokenCounter);
|
||||
var count = input.Count;
|
||||
for (var i = 0; i < count; i++)
|
||||
{
|
||||
var (splits, split) = Split(input.ValueAt(i).AsSpan(), input.ValueAt(i), maxTokens, separators, trim, tokenCounter, input.TokenCountAt(i));
|
||||
result.AddRange(splits);
|
||||
inputWasSplit |= split;
|
||||
}
|
||||
|
||||
return (result, inputWasSplit);
|
||||
}
|
||||
|
||||
private static (StringListWithTokenCount, bool) Split(ReadOnlySpan<char> input, string? inputString, int maxTokens, ReadOnlySpan<char> separators, bool trim, TokenCounter? tokenCounter, int inputTokenCount)
|
||||
{
|
||||
Debug.Assert(inputString is null || input.SequenceEqual(inputString.AsSpan()));
|
||||
StringListWithTokenCount result = new(tokenCounter);
|
||||
var inputWasSplit = false;
|
||||
|
||||
if (inputTokenCount > maxTokens)
|
||||
{
|
||||
inputWasSplit = true;
|
||||
|
||||
var half = input.Length / 2;
|
||||
var cutPoint = -1;
|
||||
|
||||
if (separators.IsEmpty)
|
||||
{
|
||||
cutPoint = half;
|
||||
}
|
||||
else if (input.Length > 2)
|
||||
{
|
||||
var pos = 0;
|
||||
while (true)
|
||||
{
|
||||
var index = input[pos..^1].IndexOfAny(separators);
|
||||
if (index < 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
index += pos;
|
||||
|
||||
if (Math.Abs(half - index) < Math.Abs(half - cutPoint))
|
||||
{
|
||||
cutPoint = index + 1;
|
||||
}
|
||||
|
||||
pos = index + 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (cutPoint > 0)
|
||||
{
|
||||
var firstHalf = input[..cutPoint];
|
||||
var secondHalf = input[cutPoint..];
|
||||
if (trim)
|
||||
{
|
||||
firstHalf = firstHalf.Trim();
|
||||
secondHalf = secondHalf.Trim();
|
||||
}
|
||||
|
||||
// Recursion
|
||||
var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(firstHalf, tokenCounter));
|
||||
result.AddRange(splits1);
|
||||
var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(secondHalf, tokenCounter));
|
||||
result.AddRange(splits2);
|
||||
|
||||
inputWasSplit = split1 || split2;
|
||||
return (result, inputWasSplit);
|
||||
}
|
||||
}
|
||||
|
||||
var resultString = inputString ?? input.ToString();
|
||||
var resultTokenCount = inputTokenCount;
|
||||
if (trim)
|
||||
{
|
||||
var trimmedResult = resultString.Trim();
|
||||
if (!trimmedResult.Equals(resultString, StringComparison.Ordinal))
|
||||
{
|
||||
resultString = trimmedResult;
|
||||
resultTokenCount = GetTokenCount(resultString, tokenCounter);
|
||||
}
|
||||
}
|
||||
|
||||
result.Add(resultString, resultTokenCount);
|
||||
|
||||
return (result, inputWasSplit);
|
||||
}
|
||||
|
||||
private static int GetTokenCount(string input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input);
|
||||
|
||||
private static int GetTokenCount(ReadOnlySpan<char> input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input.ToString());
|
||||
|
||||
private static string NormalizeLineEndings(string text) => text.Replace("\r\n", "\n").Replace('\r', '\n');
|
||||
|
||||
private static void ValidateMaxTokens(int maxTokens, string parameterName)
|
||||
{
|
||||
if (maxTokens <= 0)
|
||||
{
|
||||
throw new ArgumentOutOfRangeException(parameterName, "The maximum token count must be a positive number.");
|
||||
}
|
||||
}
|
||||
|
||||
private static int GetDefaultTokenCount(int length)
|
||||
{
|
||||
Debug.Assert(length >= 0);
|
||||
return length == 0 ? 0 : Math.Max(1, length >> 2);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.SemanticKernel.Text;
|
||||
using SqlDatabaseVectorSearch.Services;
|
||||
using SqlDatabaseVectorSearch.Settings;
|
||||
using SqlDatabaseVectorSearch.TextChunkers.Implementations;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.TextChunkers;
|
||||
|
||||
@@ -11,8 +11,8 @@ public class MarkdownTextChunker(TokenizerService tokenizerService, IOptions<App
|
||||
|
||||
public IList<string> Split(string text)
|
||||
{
|
||||
var lines = TextChunker.SplitMarkDownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
|
||||
var paragraphs = TextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
|
||||
var lines = PlainTextChunker.SplitMarkdownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
|
||||
var paragraphs = PlainTextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
|
||||
|
||||
return paragraphs;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
using Microsoft.Agents.AI.Workflows;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Workflows;
|
||||
|
||||
public partial class FormFileToEmbeddingRequestExecutor() : Executor(nameof(FormFileToEmbeddingRequestExecutor))
|
||||
{
|
||||
[MessageHandler]
|
||||
private ValueTask<EmbeddingRequest> HandleAsync(FormFileEmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken)
|
||||
{
|
||||
// Note: file.ContentType is not 100% reliable (for example, for markdown file).
|
||||
var embeddingRequest = new EmbeddingRequest(request.File.OpenReadStream(), Path.GetFileName(request.File.FileName), MimeMapping.MimeUtility.GetMimeMapping(request.File.FileName), request.DocumentId);
|
||||
return ValueTask.FromResult(embeddingRequest);
|
||||
}
|
||||
}
|
||||
|
||||
public record class FormFileEmbeddingRequest(IFormFile File, Guid? DocumentId);
|
||||
|
||||
public record class EmbeddingRequest(Stream Content, string FileName, string ContentType, Guid? DocumentId);
|
||||
@@ -0,0 +1,41 @@
|
||||
using Microsoft.Agents.AI.Workflows;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Extensions.Options;
|
||||
using SqlDatabaseVectorSearch.ContentDecoders;
|
||||
using SqlDatabaseVectorSearch.Models;
|
||||
using SqlDatabaseVectorSearch.Services;
|
||||
using SqlDatabaseVectorSearch.Settings;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Workflows;
|
||||
|
||||
public partial class GenerateEmbeddingExecutor(IServiceProvider serviceProvider, IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator, TokenizerService tokenizerService, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger) : Executor(nameof(GenerateEmbeddingExecutor))
|
||||
{
|
||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||
|
||||
[MessageHandler]
|
||||
private async ValueTask<EmbeddingResponse> HandleAsync(EmbeddingRequest request, IWorkflowContext context, CancellationToken cancellationToken)
|
||||
{
|
||||
// Extract the contents of the file.
|
||||
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(request.ContentType) ?? throw new NotSupportedException($"Content type '{request.ContentType}' is not supported.");
|
||||
var chunks = await decoder.DecodeAsync(request.Content, request.ContentType, cancellationToken);
|
||||
var chunkContents = chunks.Select(p => p.Content).ToList();
|
||||
|
||||
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
||||
var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", chunkContents));
|
||||
|
||||
// Process paragraphs in batches.
|
||||
var embeddings = new List<Embedding<float>>();
|
||||
foreach (var batch in chunkContents.Chunk(appSettings.EmbeddingBatchSize))
|
||||
{
|
||||
logger.LogDebug("Processing batch of {Count} chunks for embedding generation...", batch.Length);
|
||||
|
||||
// Generate embeddings for this batch.
|
||||
var batchEmbeddings = await embeddingGenerator.GenerateAsync(batch, cancellationToken: cancellationToken);
|
||||
embeddings.AddRange(batchEmbeddings);
|
||||
}
|
||||
|
||||
return new EmbeddingResponse(request, chunks, embeddings, tokenCount);
|
||||
}
|
||||
}
|
||||
|
||||
public record class EmbeddingResponse(EmbeddingRequest Request, IEnumerable<Chunk> Chunks, IEnumerable<Embedding<float>> Embeddings, int TokenCount);
|
||||
@@ -0,0 +1,58 @@
|
||||
using Microsoft.Agents.AI.Workflows;
|
||||
using Microsoft.Data.SqlTypes;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using SqlDatabaseVectorSearch.Data;
|
||||
using SqlDatabaseVectorSearch.Services;
|
||||
using Entities = SqlDatabaseVectorSearch.Data.Entities;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Workflows;
|
||||
|
||||
public partial class StoreEmbeddingExecutor(ApplicationDbContext dbContext, DocumentService documentService, TokenizerService tokenizerService, TimeProvider timeProvider, ILogger<VectorSearchService> logger) : Executor(nameof(StoreEmbeddingExecutor))
|
||||
{
|
||||
[MessageHandler]
|
||||
private async ValueTask<StoreEmbeddingResponse> HandleAsync(EmbeddingResponse embeddingData, IWorkflowContext context, CancellationToken cancellationToken)
|
||||
{
|
||||
var strategy = dbContext.Database.CreateExecutionStrategy();
|
||||
var document = await strategy.ExecuteAsync(async (cancellationToken) =>
|
||||
{
|
||||
await dbContext.Database.BeginTransactionAsync(cancellationToken);
|
||||
|
||||
if (embeddingData.Request.DocumentId.HasValue)
|
||||
{
|
||||
// If the user is importing a document that already exists, delete the previous one.
|
||||
await documentService.DeleteAsync(embeddingData.Request.DocumentId.Value, cancellationToken);
|
||||
}
|
||||
|
||||
var document = new Entities.Document { Id = embeddingData.Request.DocumentId.GetValueOrDefault(), Name = embeddingData.Request.FileName, CreationDate = timeProvider.GetUtcNow() };
|
||||
dbContext.Documents.Add(document);
|
||||
|
||||
// Save the document chunks and the corresponding embedding in the database.
|
||||
foreach (var (index, embedding) in embeddingData.Embeddings.Index())
|
||||
{
|
||||
var chunk = embeddingData.Chunks.ElementAt(index);
|
||||
logger.LogDebug("Storing a chunk of {TokenCount} tokens.", tokenizerService.CountEmbeddingTokens(chunk.Content));
|
||||
|
||||
var documentChunk = new Entities.DocumentChunk
|
||||
{
|
||||
Document = document,
|
||||
Index = index,
|
||||
PageNumber = chunk.PageNumber,
|
||||
IndexOnPage = chunk.IndexOnPage,
|
||||
Content = chunk.Content,
|
||||
Embedding = new SqlVector<float>(embedding.Vector)
|
||||
};
|
||||
|
||||
dbContext.DocumentChunks.Add(documentChunk);
|
||||
}
|
||||
|
||||
await dbContext.SaveChangesAsync(cancellationToken);
|
||||
await dbContext.Database.CommitTransactionAsync(cancellationToken);
|
||||
|
||||
return document;
|
||||
}, cancellationToken);
|
||||
|
||||
return new(document.Id, embeddingData.TokenCount);
|
||||
}
|
||||
}
|
||||
|
||||
public record class StoreEmbeddingResponse(Guid DocumentId, int TokenCount);
|
||||
Reference in New Issue
Block a user