mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Refactor response types and enhance token usage handling
This commit is contained in:
@@ -0,0 +1,3 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.Models;
|
||||||
|
|
||||||
|
/// <summary>
/// Carries the text produced by a chat call together with the optional token usage of that call.
/// </summary>
/// <param name="Text">The produced text. Declared nullable: a streaming message may carry only token usage.</param>
/// <param name="TokenUsage">Token usage associated with the message; defaults to <see langword="null"/>.</param>
public record class ChatResponse(
    string? Text,
    TokenUsage? TokenUsage = null);
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.Models;
|
||||||
|
|
||||||
|
/// <summary>
/// Result of a document import.
/// </summary>
/// <param name="DocumentId">Identifier of the imported document.</param>
/// <param name="TotalTokenCount">Token count reported for the import.</param>
public record class ImportDocumentResponse(
    Guid DocumentId,
    int TotalTokenCount);
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.Models;
|
||||||
|
|
||||||
|
// Question and Answer can be null when using response streaming.
|
||||||
|
/// <summary>
/// Response returned for a question, either as a whole or as one message of a streamed answer.
/// </summary>
/// <param name="Question">The question text; can be null when using response streaming.</param>
/// <param name="Answer">The answer text; can be null when using response streaming.</param>
/// <param name="StreamState">State of the stream for this message; defaults to <see langword="null"/>.</param>
/// <param name="TokenUsage">Token usage details attached to this message; defaults to <see langword="null"/>.</param>
public record class QuestionResponse(
    string? Question,
    string? Answer,
    StreamState? StreamState = null,
    TokenUsageResponse? TokenUsage = null);
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
namespace SqlDatabaseVectorSearch.Models;
|
|
||||||
|
|
||||||
// Question and Answer can be null when using response streaming.
|
|
||||||
public record class Response(string? Question, string? Answer, StreamState? StreamState = null);
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.Models;
|
||||||
|
|
||||||
|
/// <summary>
/// Input and output token counts of a single call, plus their sum.
/// </summary>
/// <param name="InputTokenCount">Number of input tokens.</param>
/// <param name="OutputTokenCount">Number of output tokens.</param>
public record class TokenUsage(int InputTokenCount, int OutputTokenCount)
{
    /// <summary>
    /// Sum of <see cref="InputTokenCount"/> and <see cref="OutputTokenCount"/>, computed on each access.
    /// </summary>
    public int TotalTokenCount
    {
        get
        {
            return InputTokenCount + OutputTokenCount;
        }
    }
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
namespace SqlDatabaseVectorSearch.Models;
|
||||||
|
|
||||||
|
/// <summary>
/// Groups the token usage figures reported alongside a question response.
/// </summary>
/// <param name="Reformulation">Token usage of the question reformulation, if any.</param>
/// <param name="EmbeddingTokenCount">Token count related to the embedding of the question, if available.</param>
/// <param name="Question">Token usage of the call that answers the question, if available.</param>
public record class TokenUsageResponse(
    TokenUsage? Reformulation,
    int? EmbeddingTokenCount,
    TokenUsage? Question);
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
namespace SqlDatabaseVectorSearch.Models;
|
|
||||||
|
|
||||||
public record class UploadDocumentResponse(Guid DocumentId);
|
|
||||||
@@ -19,6 +19,11 @@ builder.Configuration.AddJsonFile("appsettings.local.json", optional: true, relo
|
|||||||
var aiSettings = builder.Services.ConfigureAndGet<AzureOpenAISettings>(builder.Configuration, "AzureOpenAI")!;
|
var aiSettings = builder.Services.ConfigureAndGet<AzureOpenAISettings>(builder.Configuration, "AzureOpenAI")!;
|
||||||
var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configuration, nameof(AppSettings))!;
|
var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configuration, nameof(AppSettings))!;
|
||||||
|
|
||||||
|
builder.Services.ConfigureHttpJsonOptions(options =>
|
||||||
|
{
|
||||||
|
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
|
||||||
|
});
|
||||||
|
|
||||||
builder.Services.AddSingleton(TimeProvider.System);
|
builder.Services.AddSingleton(TimeProvider.System);
|
||||||
|
|
||||||
builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options =>
|
builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options =>
|
||||||
@@ -56,11 +61,6 @@ builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaType
|
|||||||
builder.Services.AddKeyedSingleton<IContentDecoder, DocxContentDecoder>("application/vnd.openxmlformats-officedocument.wordprocessingml.document");
|
builder.Services.AddKeyedSingleton<IContentDecoder, DocxContentDecoder>("application/vnd.openxmlformats-officedocument.wordprocessingml.document");
|
||||||
builder.Services.AddKeyedSingleton<IContentDecoder, TextContentDecoder>(MediaTypeNames.Text.Plain);
|
builder.Services.AddKeyedSingleton<IContentDecoder, TextContentDecoder>(MediaTypeNames.Text.Plain);
|
||||||
|
|
||||||
builder.Services.ConfigureHttpJsonOptions(options =>
|
|
||||||
{
|
|
||||||
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.Services.AddOpenApi(options =>
|
builder.Services.AddOpenApi(options =>
|
||||||
{
|
{
|
||||||
options.RemoveServerList();
|
options.RemoveServerList();
|
||||||
@@ -127,9 +127,9 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi
|
|||||||
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
|
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
|
||||||
{
|
{
|
||||||
using var stream = file.OpenReadStream();
|
using var stream = file.OpenReadStream();
|
||||||
documentId = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
|
var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
|
||||||
|
|
||||||
return TypedResults.Ok(new UploadDocumentResponse(documentId.Value));
|
return TypedResults.Ok(response);
|
||||||
})
|
})
|
||||||
.DisableAntiforgery()
|
.DisableAntiforgery()
|
||||||
.ProducesProblem(StatusCodes.Status400BadRequest)
|
.ProducesProblem(StatusCodes.Status400BadRequest)
|
||||||
@@ -157,7 +157,7 @@ app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSear
|
|||||||
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService,
|
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService,
|
||||||
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
||||||
{
|
{
|
||||||
async IAsyncEnumerable<Response> Stream()
|
async IAsyncEnumerable<QuestionResponse> Stream()
|
||||||
{
|
{
|
||||||
// Requests a streaming response.
|
// Requests a streaming response.
|
||||||
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate);
|
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate);
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ using Microsoft.Extensions.Caching.Hybrid;
|
|||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Microsoft.SemanticKernel.ChatCompletion;
|
using Microsoft.SemanticKernel.ChatCompletion;
|
||||||
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
|
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
|
||||||
|
using OpenAI.Chat;
|
||||||
|
using SqlDatabaseVectorSearch.Models;
|
||||||
using SqlDatabaseVectorSearch.Settings;
|
using SqlDatabaseVectorSearch.Settings;
|
||||||
|
|
||||||
namespace SqlDatabaseVectorSearch.Services;
|
namespace SqlDatabaseVectorSearch.Services;
|
||||||
@@ -11,7 +13,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
{
|
{
|
||||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||||
|
|
||||||
public async Task<string> CreateQuestionAsync(Guid conversationId, string question)
|
public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question)
|
||||||
{
|
{
|
||||||
var chat = await GetChatHistoryAsync(conversationId);
|
var chat = await GetChatHistoryAsync(conversationId);
|
||||||
|
|
||||||
@@ -31,10 +33,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
|
|
||||||
await UpdateCacheAsync(conversationId, chat);
|
await UpdateCacheAsync(conversationId, chat);
|
||||||
|
|
||||||
return reformulatedQuestion.Content!;
|
var tokenUsage = GetTokenUsage(reformulatedQuestion);
|
||||||
|
|
||||||
|
return new(reformulatedQuestion.Content!, tokenUsage);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<string> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
|
public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
|
||||||
{
|
{
|
||||||
var chat = CreateChatAsync(chunks, question);
|
var chat = CreateChatAsync(chunks, question);
|
||||||
|
|
||||||
@@ -46,10 +50,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
// Add question and answer to the chat history.
|
// Add question and answer to the chat history.
|
||||||
await SetChatHistoryAsync(conversationId, question, answer.Content!);
|
await SetChatHistoryAsync(conversationId, question, answer.Content!);
|
||||||
|
|
||||||
return answer.Content!;
|
var tokenUsage = GetTokenUsage(answer);
|
||||||
|
|
||||||
|
return new(answer.Content!, tokenUsage);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async IAsyncEnumerable<string> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
|
public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
|
||||||
{
|
{
|
||||||
var chat = CreateChatAsync(chunks, question);
|
var chat = CreateChatAsync(chunks, question);
|
||||||
|
|
||||||
@@ -61,15 +67,44 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
{
|
{
|
||||||
if (!string.IsNullOrEmpty(token.Content))
|
if (!string.IsNullOrEmpty(token.Content))
|
||||||
{
|
{
|
||||||
yield return token.Content;
|
yield return new(token.Content);
|
||||||
answer.Append(token.Content);
|
answer.Append(token.Content);
|
||||||
}
|
}
|
||||||
|
else if (token.Content is null)
|
||||||
|
{
|
||||||
|
// Token usage is returned in the last message, when the Content is null.
|
||||||
|
var tokenUsage = GetTokenUsage(token);
|
||||||
|
if (tokenUsage is not null)
|
||||||
|
{
|
||||||
|
yield return new(null, tokenUsage);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add question and answer to the chat history.
|
// Add question and answer to the chat history.
|
||||||
await SetChatHistoryAsync(conversationId, question, answer.ToString());
|
await SetChatHistoryAsync(conversationId, question, answer.ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
/// Extracts the token usage from a chat message.
/// </summary>
/// <param name="message">The message whose inner content is inspected.</param>
/// <returns>
/// The input/output token counts when the inner content is a <see cref="ChatCompletion"/> with a non-null usage;
/// otherwise <see langword="null"/>.
/// </returns>
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message)
    => message.InnerContent is ChatCompletion { Usage: { } usage }
        ? new TokenUsage(usage.InputTokenCount, usage.OutputTokenCount)
        : null;
|
||||||
|
|
||||||
|
/// <summary>
/// Extracts the token usage from a streaming chat message.
/// </summary>
/// <param name="message">The streaming message whose inner content is inspected.</param>
/// <returns>
/// The input/output token counts when the inner content is a <see cref="StreamingChatCompletionUpdate"/>
/// with a non-null usage; otherwise <see langword="null"/>.
/// </returns>
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.StreamingChatMessageContent message)
    => message.InnerContent is StreamingChatCompletionUpdate { Usage: { } usage }
        ? new TokenUsage(usage.InputTokenCount, usage.OutputTokenCount)
        : null;
|
||||||
|
|
||||||
private ChatHistory CreateChatAsync(IEnumerable<string> chunks, string question)
|
private ChatHistory CreateChatAsync(IEnumerable<string> chunks, string question)
|
||||||
{
|
{
|
||||||
var chat = new ChatHistory("""
|
var chat = new ChatHistory("""
|
||||||
|
|||||||
@@ -15,12 +15,15 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
{
|
{
|
||||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||||
|
|
||||||
public async Task<Guid> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
|
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
|
||||||
{
|
{
|
||||||
// Extract the contents of the file.
|
// Extract the contents of the file.
|
||||||
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
|
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
|
||||||
var content = await decoder.DecodeAsync(stream, contentType);
|
var content = await decoder.DecodeAsync(stream, contentType);
|
||||||
|
|
||||||
|
// We get the token count of the whole document because it is the total number of tokens used by the embedding (it may be needed, for example, for cost analysis).
|
||||||
|
var tokenCount = tokenizerService.CountTokens(content);
|
||||||
|
|
||||||
await dbContext.Database.BeginTransactionAsync();
|
await dbContext.Database.BeginTransactionAsync();
|
||||||
|
|
||||||
if (documentId.HasValue)
|
if (documentId.HasValue)
|
||||||
@@ -49,7 +52,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
await dbContext.SaveChangesAsync();
|
await dbContext.SaveChangesAsync();
|
||||||
await dbContext.Database.CommitTransactionAsync();
|
await dbContext.Database.CommitTransactionAsync();
|
||||||
|
|
||||||
return document.Id;
|
return new(document.Id, tokenCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<IEnumerable<Document>> GetDocumentsAsync()
|
public async Task<IEnumerable<Document>> GetDocumentsAsync()
|
||||||
@@ -82,40 +85,45 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
public Task DeleteDocumentAsync(Guid documentId)
|
public Task DeleteDocumentAsync(Guid documentId)
|
||||||
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
|
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
|
||||||
|
|
||||||
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
|
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true)
|
||||||
{
|
{
|
||||||
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
|
// If the user doesn't want to reformulate the question, CreateContextAsync returns the original one.
|
||||||
|
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
|
||||||
|
|
||||||
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
|
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
||||||
return new Response(reformulatedQuestion, answer);
|
|
||||||
|
return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true)
|
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
|
||||||
{
|
{
|
||||||
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
|
// If the user doesn't want to reformulate the question, CreateContextAsync returns the original one.
|
||||||
|
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
|
||||||
|
|
||||||
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion);
|
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
||||||
|
|
||||||
// The first message contains the original question.
|
// The first message contains the question and the corresponding token usage (if reformulated).
|
||||||
yield return new Response(reformulatedQuestion, null, StreamState.Start);
|
yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
|
||||||
|
|
||||||
// Return each token as a partial response.
|
// Return each token as a partial response.
|
||||||
await foreach (var token in answerStream)
|
await foreach (var (token, tokenUsage) in answerStream)
|
||||||
{
|
{
|
||||||
yield return new Response(null, token, StreamState.Append);
|
// Token usage is contained in the last message.
|
||||||
|
yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The last message tells the client that the stream has ended.
|
// The last message tells the client that the stream has ended.
|
||||||
yield return new Response(null, null, StreamState.End);
|
yield return new(null, null, StreamState.End);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task<(string Question, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
|
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
|
||||||
{
|
{
|
||||||
// Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings:
|
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
|
||||||
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text;
|
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
|
||||||
|
|
||||||
// Perform Vector Search on SQL Database.
|
// Perform Vector Search on SQL Database.
|
||||||
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
|
var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!);
|
||||||
|
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
|
||||||
|
|
||||||
var chunks = await dbContext.DocumentChunks
|
var chunks = await dbContext.DocumentChunks
|
||||||
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
|
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
|
||||||
@@ -123,6 +131,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
.Take(appSettings.MaxRelevantChunks)
|
.Take(appSettings.MaxRelevantChunks)
|
||||||
.ToListAsync();
|
.ToListAsync();
|
||||||
|
|
||||||
return (reformulatedQuestion, chunks);
|
return (reformulatedQuestion, embeddingTokenCount, chunks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user