mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Refactor response types and enhance token usage handling
This commit is contained in:
@@ -15,12 +15,15 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
{
|
||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||
|
||||
public async Task<Guid> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
|
||||
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
|
||||
{
|
||||
// Extract the contents of the file.
|
||||
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
|
||||
var content = await decoder.DecodeAsync(stream, contentType);
|
||||
|
||||
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
||||
var tokenCount = tokenizerService.CountTokens(content);
|
||||
|
||||
await dbContext.Database.BeginTransactionAsync();
|
||||
|
||||
if (documentId.HasValue)
|
||||
@@ -49,7 +52,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
await dbContext.SaveChangesAsync();
|
||||
await dbContext.Database.CommitTransactionAsync();
|
||||
|
||||
return document.Id;
|
||||
return new(document.Id, tokenCount);
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<Document>> GetDocumentsAsync()
|
||||
@@ -82,40 +85,45 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
public Task DeleteDocumentAsync(Guid documentId)
|
||||
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
|
||||
|
||||
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
|
||||
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true)
|
||||
{
|
||||
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
|
||||
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
||||
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
|
||||
|
||||
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
|
||||
return new Response(reformulatedQuestion, answer);
|
||||
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
||||
|
||||
return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true)
|
||||
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
|
||||
{
|
||||
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
|
||||
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
||||
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
|
||||
|
||||
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion);
|
||||
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
||||
|
||||
// The first message contains the original question.
|
||||
yield return new Response(reformulatedQuestion, null, StreamState.Start);
|
||||
// The first message contains the question and the corresponding token usage (if reformulated).
|
||||
yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
|
||||
|
||||
// Return each token as a partial response.
|
||||
await foreach (var token in answerStream)
|
||||
await foreach (var (token, tokenUsage) in answerStream)
|
||||
{
|
||||
yield return new Response(null, token, StreamState.Append);
|
||||
// Token usage is contained in the last message.
|
||||
yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null);
|
||||
}
|
||||
|
||||
// The last message tells the client that the stream has ended.
|
||||
yield return new Response(null, null, StreamState.End);
|
||||
yield return new(null, null, StreamState.End);
|
||||
}
|
||||
|
||||
private async Task<(string Question, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
|
||||
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
|
||||
{
|
||||
// Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings:
|
||||
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text;
|
||||
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
|
||||
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
|
||||
|
||||
// Perform Vector Search on SQL Database.
|
||||
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
|
||||
var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!);
|
||||
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
|
||||
|
||||
var chunks = await dbContext.DocumentChunks
|
||||
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
|
||||
@@ -123,6 +131,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
.Take(appSettings.MaxRelevantChunks)
|
||||
.ToListAsync();
|
||||
|
||||
return (reformulatedQuestion, chunks);
|
||||
return (reformulatedQuestion, embeddingTokenCount, chunks);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user