Refactor response handling and content decoding

- Updated `TextContentDecoder` to use `ITextChunker` for paragraph splitting and return a list of `Chunk` objects.
- Changed return type of `Stream` method in `AskEndpoints.cs` from `IAsyncEnumerable<QuestionResponse>` to `IAsyncEnumerable<Response>`.
- Removed `QuestionResponse` class and introduced `Response` class to better handle streaming responses.
- Modified `AskQuestionAsync` and `AskStreamingAsync` methods in `VectorSearchService` to return `Response` instead of `QuestionResponse`, and adjusted token count calculation.
- Added namespace declaration in `Response.cs` and defined properties to align with new response structure.
This commit is contained in:
Marco Minerva
2025-06-04 10:22:15 +02:00
parent a7fef36b66
commit 2fc070d0aa
5 changed files with 22 additions and 17 deletions
@@ -1,12 +1,17 @@
namespace SqlDatabaseVectorSearch.ContentDecoders; using SqlDatabaseVectorSearch.TextChunkers;
public class TextContentDecoder : IContentDecoder namespace SqlDatabaseVectorSearch.ContentDecoders;
public class TextContentDecoder(IServiceProvider serviceProvider) : IContentDecoder
{ {
public async Task<IEnumerable<Chunk>> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default) public async Task<IEnumerable<Chunk>> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
{ {
var textChunker = serviceProvider.GetRequiredKeyedService<ITextChunker>(contentType);
using var readStream = new StreamReader(stream); using var readStream = new StreamReader(stream);
var content = await readStream.ReadToEndAsync(cancellationToken); var content = await readStream.ReadToEndAsync(cancellationToken);
return [new(1, 0, content)]; var paragraphs = textChunker.Split(content);
return paragraphs.Select((text, index) => new Chunk(1, index, text)).ToList();
} }
} }
@@ -23,7 +23,7 @@ public class AskEndpoints : IEndpointRouteHandlerBuilder
endpoints.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken, endpoints.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
{ {
async IAsyncEnumerable<QuestionResponse> Stream() async IAsyncEnumerable<Response> Stream()
{ {
// Requests a streaming response. // Requests a streaming response.
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken); var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken);
@@ -1,10 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Answer can be null when using response streaming.
public record class QuestionResponse(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null)
{
public QuestionResponse(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null)
: this(null, null, token, streamState, tokenUsageResponse)
{
}
}
@@ -0,0 +1,10 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Answer can be null when using response streaming.
public record class Response(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null)
{
public Response(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null)
: this(null, null, token, streamState, tokenUsageResponse)
{
}
}
@@ -23,7 +23,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
var paragraphs = await decoder.DecodeAsync(stream, contentType, cancellationToken); var paragraphs = await decoder.DecodeAsync(stream, contentType, cancellationToken);
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(string.Empty, paragraphs.Select(p => p.Content))); var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", paragraphs.Select(p => p.Content)));
var strategy = dbContext.Database.CreateExecutionStrategy(); var strategy = dbContext.Database.CreateExecutionStrategy();
var document = await strategy.ExecuteAsync(async (cancellationToken) => var document = await strategy.ExecuteAsync(async (cancellationToken) =>
@@ -59,7 +59,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
return new(document.Id, tokenCount); return new(document.Id, tokenCount);
} }
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default) public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
{ {
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken); var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
@@ -69,7 +69,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
} }
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default) public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{ {
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken); var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);