Refactor services and update token handling

- Replace `TotalTokenCount` with `EmbeddingTokenCount` in `ImportDocumentResponse`.
- Add `OriginalQuestion` and `ReformulatedQuestion` fields to `QuestionResponse` and a new constructor.
- Add a new constructor to `TokenUsageResponse` to initialize `Question`.
- Add `TextChunkerService` to service collection in `Program.cs`.
- Clarify prompt and update token counting in `ChatService`.
- Differentiate token counting in `TokenizerService` with `CountChatCompletionTokens` and `CountEmbeddingTokens`.
- Update `VectorSearchService` to use `TextChunkerService` and new token counting methods.
- Introduce `TextChunkerService` for text splitting and tokenization.
This commit is contained in:
Marco Minerva
2025-02-07 10:24:16 +01:00
parent 8e06979993
commit cdd0199e8f
8 changed files with 65 additions and 25 deletions
@@ -2,7 +2,6 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.ContentDecoders;
using SqlDatabaseVectorSearch.DataAccessLayer;
using SqlDatabaseVectorSearch.Models;
@@ -11,7 +10,7 @@ using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities;
namespace SqlDatabaseVectorSearch.Services;
public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TokenizerService tokenizerService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, TokenizerService tokenizerService, TextChunkerService textChunkerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
@@ -22,7 +21,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
var content = await decoder.DecodeAsync(stream, contentType);
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
var tokenCount = tokenizerService.CountTokens(content);
var tokenCount = tokenizerService.CountEmbeddingTokens(content);
await dbContext.Database.BeginTransactionAsync();
@@ -36,14 +35,13 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
dbContext.Documents.Add(document);
// Split the content into chunks and generate the embeddings for each one.
var lines = TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine, tokenizerService.CountTokens);
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountTokens);
var paragraphs = textChunkerService.Split(content);
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs);
// Save the document chunks and the corresponding embedding in the database.
foreach (var (index, paragraph) in paragraphs.Index())
{
logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountTokens(paragraph));
logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountChatCompletionTokens(paragraph));
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() };
dbContext.DocumentChunks.Add(documentChunk);
@@ -92,7 +90,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
}
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
@@ -103,26 +101,32 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
// The first message contains the question and the corresponding token usage (if reformulated).
yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
TokenUsageResponse? tokenUsageResponse = null;
// Return each token as a partial response.
await foreach (var (token, tokenUsage) in answerStream)
{
// Token usage is contained in the last message.
yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null);
// Token usage is expected in the last message.
tokenUsageResponse = tokenUsage is not null ? new(tokenUsage) : null;
yield return new(token, tokenUsageResponse is null ? StreamState.Append : StreamState.End, tokenUsageResponse);
}
// The last message tells the client that the stream has ended.
yield return new(null, null, StreamState.End);
// If the token usage has not been returned in the last message, we must explicitly tells that the stream is ended.
if (tokenUsageResponse is null)
{
yield return new(null, StreamState.End);
}
}
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
{
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!);
// Perform Vector Search on SQL Database.
var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!);
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
var chunks = await dbContext.DocumentChunks