diff --git a/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs b/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs index 3235b8d..d03e32f 100644 --- a/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs +++ b/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs @@ -1,12 +1,17 @@ -namespace SqlDatabaseVectorSearch.ContentDecoders; +using SqlDatabaseVectorSearch.TextChunkers; -public class TextContentDecoder : IContentDecoder +namespace SqlDatabaseVectorSearch.ContentDecoders; + +public class TextContentDecoder(IServiceProvider serviceProvider) : IContentDecoder { public async Task> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default) { + var textChunker = serviceProvider.GetRequiredKeyedService(contentType); + using var readStream = new StreamReader(stream); var content = await readStream.ReadToEndAsync(cancellationToken); - return [new(1, 0, content)]; + var paragraphs = textChunker.Split(content); + return paragraphs.Select((text, index) => new Chunk(1, index, text)).ToList(); } } diff --git a/SqlDatabaseVectorSearch/Endpoints/AskEndpoints.cs b/SqlDatabaseVectorSearch/Endpoints/AskEndpoints.cs index 1976fe7..5b236ff 100644 --- a/SqlDatabaseVectorSearch/Endpoints/AskEndpoints.cs +++ b/SqlDatabaseVectorSearch/Endpoints/AskEndpoints.cs @@ -23,7 +23,7 @@ public class AskEndpoints : IEndpointRouteHandlerBuilder endpoints.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken, [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => { - async IAsyncEnumerable Stream() + async IAsyncEnumerable Stream() { // Requests a streaming response. var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken); diff --git a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs deleted file mode 100644 index 958f9d1..0000000 --- a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs +++ /dev/null @@ -1,10 +0,0 @@ -namespace SqlDatabaseVectorSearch.Models; - -// Question and Answer can be null when using response streaming. -public record class QuestionResponse(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null) -{ - public QuestionResponse(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null) - : this(null, null, token, streamState, tokenUsageResponse) - { - } -} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/Response.cs b/SqlDatabaseVectorSearch/Models/Response.cs new file mode 100644 index 0000000..54921ba --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/Response.cs @@ -0,0 +1,10 @@ +namespace SqlDatabaseVectorSearch.Models; + +// Question and Answer can be null when using response streaming. +public record class Response(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null) +{ + public Response(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null) + : this(null, null, token, streamState, tokenUsageResponse) + { + } +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index c3c7e63..d2819ac 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -23,7 +23,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb var paragraphs = await decoder.DecodeAsync(stream, contentType, cancellationToken); // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). - var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(string.Empty, paragraphs.Select(p => p.Content))); + var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", paragraphs.Select(p => p.Content))); var strategy = dbContext.Database.CreateExecutionStrategy(); var document = await strategy.ExecuteAsync(async (cancellationToken) => @@ -59,7 +59,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb return new(document.Id, tokenCount); } - public async Task AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default) + public async Task AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default) { // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken); @@ -69,7 +69,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); } - public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default) { // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);