mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Refactor response handling and content decoding
- Updated `TextContentDecoder` to use `ITextChunker` for paragraph splitting and return a list of `Chunk` objects. - Changed return type of `Stream` method in `AskEndpoints.cs` from `IAsyncEnumerable<QuestionResponse>` to `IAsyncEnumerable<Response>`. - Removed `QuestionResponse` class and introduced `Response` class to better handle streaming responses. - Modified `AskQuestionAsync` and `AskStreamingAsync` methods in `VectorSearchService` to return `Response` instead of `QuestionResponse`, and adjusted token count calculation. - Added namespace declaration in `Response.cs` and defined properties to align with new response structure.
This commit is contained in:
@@ -1,12 +1,17 @@
|
|||||||
namespace SqlDatabaseVectorSearch.ContentDecoders;
|
using SqlDatabaseVectorSearch.TextChunkers;
|
||||||
|
|
||||||
public class TextContentDecoder : IContentDecoder
|
namespace SqlDatabaseVectorSearch.ContentDecoders;
|
||||||
|
|
||||||
|
public class TextContentDecoder(IServiceProvider serviceProvider) : IContentDecoder
|
||||||
{
|
{
|
||||||
public async Task<IEnumerable<Chunk>> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
|
public async Task<IEnumerable<Chunk>> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
|
var textChunker = serviceProvider.GetRequiredKeyedService<ITextChunker>(contentType);
|
||||||
|
|
||||||
using var readStream = new StreamReader(stream);
|
using var readStream = new StreamReader(stream);
|
||||||
var content = await readStream.ReadToEndAsync(cancellationToken);
|
var content = await readStream.ReadToEndAsync(cancellationToken);
|
||||||
|
|
||||||
return [new(1, 0, content)];
|
var paragraphs = textChunker.Split(content);
|
||||||
|
return paragraphs.Select((text, index) => new Chunk(1, index, text)).ToList();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ public class AskEndpoints : IEndpointRouteHandlerBuilder
|
|||||||
endpoints.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
endpoints.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
||||||
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
||||||
{
|
{
|
||||||
async IAsyncEnumerable<QuestionResponse> Stream()
|
async IAsyncEnumerable<Response> Stream()
|
||||||
{
|
{
|
||||||
// Requests a streaming response.
|
// Requests a streaming response.
|
||||||
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken);
|
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken);
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
namespace SqlDatabaseVectorSearch.Models;
|
|
||||||
|
|
||||||
// Question and Answer can be null when using response streaming.
|
|
||||||
public record class QuestionResponse(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null)
|
|
||||||
{
|
|
||||||
public QuestionResponse(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null)
|
|
||||||
: this(null, null, token, streamState, tokenUsageResponse)
|
|
||||||
{
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.Models;
|
||||||
|
|
||||||
|
// Question and Answer can be null when using response streaming.
|
||||||
|
public record class Response(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null)
|
||||||
|
{
|
||||||
|
public Response(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null)
|
||||||
|
: this(null, null, token, streamState, tokenUsageResponse)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,7 +23,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
var paragraphs = await decoder.DecodeAsync(stream, contentType, cancellationToken);
|
var paragraphs = await decoder.DecodeAsync(stream, contentType, cancellationToken);
|
||||||
|
|
||||||
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
||||||
var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(string.Empty, paragraphs.Select(p => p.Content)));
|
var tokenCount = tokenizerService.CountEmbeddingTokens(string.Join(" ", paragraphs.Select(p => p.Content)));
|
||||||
|
|
||||||
var strategy = dbContext.Database.CreateExecutionStrategy();
|
var strategy = dbContext.Database.CreateExecutionStrategy();
|
||||||
var document = await strategy.ExecuteAsync(async (cancellationToken) =>
|
var document = await strategy.ExecuteAsync(async (cancellationToken) =>
|
||||||
@@ -59,7 +59,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
return new(document.Id, tokenCount);
|
return new(document.Id, tokenCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
|
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
||||||
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
|
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
|
||||||
@@ -69,7 +69,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
||||||
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
|
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
|
||||||
|
|||||||
Reference in New Issue
Block a user