Refactor response types and enhance token usage handling

This commit is contained in:
Marco Minerva
2025-01-30 12:56:33 +01:00
parent dd58c547d0
commit 8e06979993
10 changed files with 95 additions and 40 deletions
@@ -15,12 +15,15 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<Guid> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
{
// Extract the contents of the file.
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
var content = await decoder.DecodeAsync(stream, contentType);
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
var tokenCount = tokenizerService.CountTokens(content);
await dbContext.Database.BeginTransactionAsync();
if (documentId.HasValue)
@@ -49,7 +52,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
await dbContext.SaveChangesAsync();
await dbContext.Database.CommitTransactionAsync();
return document.Id;
return new(document.Id, tokenCount);
}
public async Task<IEnumerable<Document>> GetDocumentsAsync()
@@ -82,40 +85,45 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
public Task DeleteDocumentAsync(Guid documentId)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true)
{
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
return new Response(reformulatedQuestion, answer);
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
}
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true)
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
{
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion);
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
// The first message contains the original question.
yield return new Response(reformulatedQuestion, null, StreamState.Start);
// The first message contains the question and the corresponding token usage (if reformulated).
yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
// Return each token as a partial response.
await foreach (var token in answerStream)
await foreach (var (token, tokenUsage) in answerStream)
{
yield return new Response(null, token, StreamState.Append);
// Token usage is contained in the last message.
yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null);
}
// The last message tells the client that the stream has ended.
yield return new Response(null, null, StreamState.End);
yield return new(null, null, StreamState.End);
}
private async Task<(string Question, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
{
// Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings:
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text;
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
// Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!);
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
var chunks = await dbContext.DocumentChunks
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
@@ -123,6 +131,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
.Take(appSettings.MaxRelevantChunks)
.ToListAsync();
return (reformulatedQuestion, chunks);
return (reformulatedQuestion, embeddingTokenCount, chunks);
}
}