Refactor response types and enhance token usage handling

This commit is contained in:
Marco Minerva
2025-01-30 12:56:33 +01:00
parent dd58c547d0
commit 8e06979993
10 changed files with 95 additions and 40 deletions
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class ChatResponse(string? Text, TokenUsage? TokenUsage = null);
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class ImportDocumentResponse(Guid DocumentId, int TotalTokenCount);
@@ -0,0 +1,4 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Asnwer can be null when using response streaming.
public record class QuestionResponse(string? Question, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null);
@@ -1,4 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Asnwer can be null when using response streaming.
public record class Response(string? Question, string? Answer, StreamState? StreamState = null);
@@ -0,0 +1,6 @@
namespace SqlDatabaseVectorSearch.Models;
public record class TokenUsage(int InputTokenCount, int OutputTokenCount)
{
public int TotalTokenCount => InputTokenCount + OutputTokenCount;
}
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question);
@@ -1,3 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
public record class UploadDocumentResponse(Guid DocumentId);
+8 -8
View File
@@ -19,6 +19,11 @@ builder.Configuration.AddJsonFile("appsettings.local.json", optional: true, relo
var aiSettings = builder.Services.ConfigureAndGet<AzureOpenAISettings>(builder.Configuration, "AzureOpenAI")!;
var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configuration, nameof(AppSettings))!;
builder.Services.ConfigureHttpJsonOptions(options =>
{
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
});
builder.Services.AddSingleton(TimeProvider.System);
builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options =>
@@ -56,11 +61,6 @@ builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaType
builder.Services.AddKeyedSingleton<IContentDecoder, DocxContentDecoder>("application/vnd.openxmlformats-officedocument.wordprocessingml.document");
builder.Services.AddKeyedSingleton<IContentDecoder, TextContentDecoder>(MediaTypeNames.Text.Plain);
builder.Services.ConfigureHttpJsonOptions(options =>
{
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
});
builder.Services.AddOpenApi(options =>
{
options.RemoveServerList();
@@ -127,9 +127,9 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
{
using var stream = file.OpenReadStream();
documentId = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
return TypedResults.Ok(new UploadDocumentResponse(documentId.Value));
return TypedResults.Ok(response);
})
.DisableAntiforgery()
.ProducesProblem(StatusCodes.Status400BadRequest)
@@ -157,7 +157,7 @@ app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSear
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService,
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
{
async IAsyncEnumerable<Response> Stream()
async IAsyncEnumerable<QuestionResponse> Stream()
{
// Requests a streaming response.
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate);
@@ -3,6 +3,8 @@ using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using OpenAI.Chat;
using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Settings;
namespace SqlDatabaseVectorSearch.Services;
@@ -11,7 +13,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<string> CreateQuestionAsync(Guid conversationId, string question)
public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question)
{
var chat = await GetChatHistoryAsync(conversationId);
@@ -31,10 +33,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
await UpdateCacheAsync(conversationId, chat);
return reformulatedQuestion.Content!;
var tokenUsage = GetTokenUsage(reformulatedQuestion);
return new(reformulatedQuestion.Content!, tokenUsage);
}
public async Task<string> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
{
var chat = CreateChatAsync(chunks, question);
@@ -46,10 +50,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
// Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.Content!);
return answer.Content!;
var tokenUsage = GetTokenUsage(answer);
return new(answer.Content!, tokenUsage);
}
public async IAsyncEnumerable<string> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
{
var chat = CreateChatAsync(chunks, question);
@@ -61,15 +67,44 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
{
if (!string.IsNullOrEmpty(token.Content))
{
yield return token.Content;
yield return new(token.Content);
answer.Append(token.Content);
}
else if (token.Content is null)
{
// Token usage is returned in the last message, when the Content is null.
var tokenUsage = GetTokenUsage(token);
if (tokenUsage is not null)
{
yield return new(null, tokenUsage);
}
}
}
// Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.ToString());
}
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message)
{
if (message.InnerContent is ChatCompletion content && content.Usage is not null)
{
return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount);
}
return null;
}
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.StreamingChatMessageContent message)
{
if (message.InnerContent is StreamingChatCompletionUpdate content && content.Usage is not null)
{
return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount);
}
return null;
}
private ChatHistory CreateChatAsync(IEnumerable<string> chunks, string question)
{
var chat = new ChatHistory("""
@@ -15,12 +15,15 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<Guid> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
{
// Extract the contents of the file.
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
var content = await decoder.DecodeAsync(stream, contentType);
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
var tokenCount = tokenizerService.CountTokens(content);
await dbContext.Database.BeginTransactionAsync();
if (documentId.HasValue)
@@ -49,7 +52,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
await dbContext.SaveChangesAsync();
await dbContext.Database.CommitTransactionAsync();
return document.Id;
return new(document.Id, tokenCount);
}
public async Task<IEnumerable<Document>> GetDocumentsAsync()
@@ -82,40 +85,45 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
public Task DeleteDocumentAsync(Guid documentId)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true)
{
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
return new Response(reformulatedQuestion, answer);
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
}
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true)
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
{
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion);
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
// The first message contains the original question.
yield return new Response(reformulatedQuestion, null, StreamState.Start);
// The first message contains the question and the corresponding token usage (if reformulated).
yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
// Return each token as a partial response.
await foreach (var token in answerStream)
await foreach (var (token, tokenUsage) in answerStream)
{
yield return new Response(null, token, StreamState.Append);
// Token usage is contained in the last message.
yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null);
}
// The last message tells the client that the stream has ended.
yield return new Response(null, null, StreamState.End);
yield return new(null, null, StreamState.End);
}
private async Task<(string Question, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
{
// Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings:
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text;
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
// Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!);
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
var chunks = await dbContext.DocumentChunks
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
@@ -123,6 +131,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
.Take(appSettings.MaxRelevantChunks)
.ToListAsync();
return (reformulatedQuestion, chunks);
return (reformulatedQuestion, embeddingTokenCount, chunks);
}
}