Refactor response types and enhance token usage handling

This commit is contained in:
Marco Minerva
2025-01-30 12:56:33 +01:00
parent dd58c547d0
commit 8e06979993
10 changed files with 95 additions and 40 deletions
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class ChatResponse(string? Text, TokenUsage? TokenUsage = null);
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class ImportDocumentResponse(Guid DocumentId, int TotalTokenCount);
@@ -0,0 +1,4 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Asnwer can be null when using response streaming.
public record class QuestionResponse(string? Question, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null);
@@ -1,4 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Asnwer can be null when using response streaming.
public record class Response(string? Question, string? Answer, StreamState? StreamState = null);
@@ -0,0 +1,6 @@
namespace SqlDatabaseVectorSearch.Models;
public record class TokenUsage(int InputTokenCount, int OutputTokenCount)
{
public int TotalTokenCount => InputTokenCount + OutputTokenCount;
}
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question);
@@ -1,3 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
public record class UploadDocumentResponse(Guid DocumentId);
+8 -8
View File
@@ -19,6 +19,11 @@ builder.Configuration.AddJsonFile("appsettings.local.json", optional: true, relo
var aiSettings = builder.Services.ConfigureAndGet<AzureOpenAISettings>(builder.Configuration, "AzureOpenAI")!; var aiSettings = builder.Services.ConfigureAndGet<AzureOpenAISettings>(builder.Configuration, "AzureOpenAI")!;
var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configuration, nameof(AppSettings))!; var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configuration, nameof(AppSettings))!;
builder.Services.ConfigureHttpJsonOptions(options =>
{
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
});
builder.Services.AddSingleton(TimeProvider.System); builder.Services.AddSingleton(TimeProvider.System);
builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options => builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options =>
@@ -56,11 +61,6 @@ builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaType
builder.Services.AddKeyedSingleton<IContentDecoder, DocxContentDecoder>("application/vnd.openxmlformats-officedocument.wordprocessingml.document"); builder.Services.AddKeyedSingleton<IContentDecoder, DocxContentDecoder>("application/vnd.openxmlformats-officedocument.wordprocessingml.document");
builder.Services.AddKeyedSingleton<IContentDecoder, TextContentDecoder>(MediaTypeNames.Text.Plain); builder.Services.AddKeyedSingleton<IContentDecoder, TextContentDecoder>(MediaTypeNames.Text.Plain);
builder.Services.ConfigureHttpJsonOptions(options =>
{
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
});
builder.Services.AddOpenApi(options => builder.Services.AddOpenApi(options =>
{ {
options.RemoveServerList(); options.RemoveServerList();
@@ -127,9 +127,9 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
{ {
using var stream = file.OpenReadStream(); using var stream = file.OpenReadStream();
documentId = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId); var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
return TypedResults.Ok(new UploadDocumentResponse(documentId.Value)); return TypedResults.Ok(response);
}) })
.DisableAntiforgery() .DisableAntiforgery()
.ProducesProblem(StatusCodes.Status400BadRequest) .ProducesProblem(StatusCodes.Status400BadRequest)
@@ -157,7 +157,7 @@ app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSear
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService,
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
{ {
async IAsyncEnumerable<Response> Stream() async IAsyncEnumerable<QuestionResponse> Stream()
{ {
// Requests a streaming response. // Requests a streaming response.
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate);
@@ -3,6 +3,8 @@ using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using OpenAI.Chat;
using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Settings; using SqlDatabaseVectorSearch.Settings;
namespace SqlDatabaseVectorSearch.Services; namespace SqlDatabaseVectorSearch.Services;
@@ -11,7 +13,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
{ {
private readonly AppSettings appSettings = appSettingsOptions.Value; private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<string> CreateQuestionAsync(Guid conversationId, string question) public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question)
{ {
var chat = await GetChatHistoryAsync(conversationId); var chat = await GetChatHistoryAsync(conversationId);
@@ -31,10 +33,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
await UpdateCacheAsync(conversationId, chat); await UpdateCacheAsync(conversationId, chat);
return reformulatedQuestion.Content!; var tokenUsage = GetTokenUsage(reformulatedQuestion);
return new(reformulatedQuestion.Content!, tokenUsage);
} }
public async Task<string> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question) public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
{ {
var chat = CreateChatAsync(chunks, question); var chat = CreateChatAsync(chunks, question);
@@ -46,10 +50,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
// Add question and answer to the chat history. // Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.Content!); await SetChatHistoryAsync(conversationId, question, answer.Content!);
return answer.Content!; var tokenUsage = GetTokenUsage(answer);
return new(answer.Content!, tokenUsage);
} }
public async IAsyncEnumerable<string> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question) public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
{ {
var chat = CreateChatAsync(chunks, question); var chat = CreateChatAsync(chunks, question);
@@ -61,15 +67,44 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
{ {
if (!string.IsNullOrEmpty(token.Content)) if (!string.IsNullOrEmpty(token.Content))
{ {
yield return token.Content; yield return new(token.Content);
answer.Append(token.Content); answer.Append(token.Content);
} }
else if (token.Content is null)
{
// Token usage is returned in the last message, when the Content is null.
var tokenUsage = GetTokenUsage(token);
if (tokenUsage is not null)
{
yield return new(null, tokenUsage);
}
}
} }
// Add question and answer to the chat history. // Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.ToString()); await SetChatHistoryAsync(conversationId, question, answer.ToString());
} }
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message)
{
if (message.InnerContent is ChatCompletion content && content.Usage is not null)
{
return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount);
}
return null;
}
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.StreamingChatMessageContent message)
{
if (message.InnerContent is StreamingChatCompletionUpdate content && content.Usage is not null)
{
return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount);
}
return null;
}
private ChatHistory CreateChatAsync(IEnumerable<string> chunks, string question) private ChatHistory CreateChatAsync(IEnumerable<string> chunks, string question)
{ {
var chat = new ChatHistory(""" var chat = new ChatHistory("""
@@ -15,12 +15,15 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
{ {
private readonly AppSettings appSettings = appSettingsOptions.Value; private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<Guid> ImportAsync(Stream stream, string name, string contentType, Guid? documentId) public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
{ {
// Extract the contents of the file. // Extract the contents of the file.
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported."); var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
var content = await decoder.DecodeAsync(stream, contentType); var content = await decoder.DecodeAsync(stream, contentType);
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
var tokenCount = tokenizerService.CountTokens(content);
await dbContext.Database.BeginTransactionAsync(); await dbContext.Database.BeginTransactionAsync();
if (documentId.HasValue) if (documentId.HasValue)
@@ -49,7 +52,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
await dbContext.SaveChangesAsync(); await dbContext.SaveChangesAsync();
await dbContext.Database.CommitTransactionAsync(); await dbContext.Database.CommitTransactionAsync();
return document.Id; return new(document.Id, tokenCount);
} }
public async Task<IEnumerable<Document>> GetDocumentsAsync() public async Task<IEnumerable<Document>> GetDocumentsAsync()
@@ -82,40 +85,45 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
public Task DeleteDocumentAsync(Guid documentId) public Task DeleteDocumentAsync(Guid documentId)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true) public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true)
{ {
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
return new Response(reformulatedQuestion, answer);
return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
} }
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true) public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
{ {
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion); var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
// The first message contains the original question. // The first message contains the question and the corresponding token usage (if reformulated).
yield return new Response(reformulatedQuestion, null, StreamState.Start); yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
// Return each token as a partial response. // Return each token as a partial response.
await foreach (var token in answerStream) await foreach (var (token, tokenUsage) in answerStream)
{ {
yield return new Response(null, token, StreamState.Append); // Token usage is contained in the last message.
yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null);
} }
// The last message tells the client that the stream has ended. // The last message tells the client that the stream has ended.
yield return new Response(null, null, StreamState.End); yield return new(null, null, StreamState.End);
} }
private async Task<(string Question, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true) private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
{ {
// Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings: // Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text; var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
// Perform Vector Search on SQL Database. // Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!);
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
var chunks = await dbContext.DocumentChunks var chunks = await dbContext.DocumentChunks
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) .OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
@@ -123,6 +131,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
.Take(appSettings.MaxRelevantChunks) .Take(appSettings.MaxRelevantChunks)
.ToListAsync(); .ToListAsync();
return (reformulatedQuestion, chunks); return (reformulatedQuestion, embeddingTokenCount, chunks);
} }
} }