mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Refactor services and update token handling
- Replace `TotalTokenCount` with `EmbeddingTokenCount` in `ImportDocumentResponse`. - Add `OriginalQuestion` and `ReformulatedQuestion` fields to `QuestionResponse` and a new constructor. - Add a new constructor to `TokenUsageResponse` to initialize `Question`. - Add `TextChunkerService` to service collection in `Program.cs`. - Clarify prompt and update token counting in `ChatService`. - Differentiate token counting in `TokenizerService` with `CountChatCompletionTokens` and `CountEmbeddingTokens`. - Update `VectorSearchService` to use `TextChunkerService` and new token counting methods. - Introduce `TextChunkerService` for text splitting and tokenization.
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
namespace SqlDatabaseVectorSearch.Models;
|
||||
|
||||
public record class ImportDocumentResponse(Guid DocumentId, int TotalTokenCount);
|
||||
public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount);
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
namespace SqlDatabaseVectorSearch.Models;
|
||||
|
||||
// Question and Asnwer can be null when using response streaming.
|
||||
public record class QuestionResponse(string? Question, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null);
|
||||
// Question and Answer can be null when using response streaming.
|
||||
public record class QuestionResponse(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null)
|
||||
{
|
||||
public QuestionResponse(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null)
|
||||
: this(null, null, token, streamState, tokenUsageResponse)
|
||||
{
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,9 @@
|
||||
namespace SqlDatabaseVectorSearch.Models;
|
||||
|
||||
public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question);
|
||||
public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question)
|
||||
{
|
||||
public TokenUsageResponse(TokenUsage? question)
|
||||
: this(null, null, question)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ builder.Services.AddKernel()
|
||||
.AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions)
|
||||
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey);
|
||||
|
||||
builder.Services.AddSingleton<TextChunkerService>();
|
||||
builder.Services.AddSingleton<TokenizerService>();
|
||||
builder.Services.AddSingleton<ChatService>();
|
||||
builder.Services.AddScoped<VectorSearchService>();
|
||||
|
||||
@@ -22,7 +22,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
||||
---
|
||||
{question}
|
||||
---
|
||||
You must reformulate the question in the same language of the user's question.
|
||||
You must reformulate the question in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English.
|
||||
Never add "in this chat", "in the context of this chat", "in the context of our conversation", "search for" or something like that in your answer.
|
||||
""";
|
||||
|
||||
@@ -111,7 +111,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
||||
You can use only the information provided in this chat to answer questions. If you don't know the answer, reply suggesting to refine the question.
|
||||
For example, if the user asks "What is the capital of France?" and in this chat there isn't information about France, you should reply something like "This information isn't available in the given context".
|
||||
Never answer to questions that are not related to this chat.
|
||||
You must answer in the same language of the user's question.
|
||||
You must answer in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English.
|
||||
""");
|
||||
|
||||
var prompt = new StringBuilder($"""
|
||||
@@ -124,15 +124,15 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
||||
""");
|
||||
|
||||
var tokensAvailable = appSettings.MaxInputTokens
|
||||
- tokenizerService.CountTokens(chat[0].ToString()) // System prompt.
|
||||
- tokenizerService.CountTokens(prompt.ToString()) // Initial user prompt.
|
||||
- tokenizerService.CountChatCompletionTokens(chat[0].ToString()) // System prompt.
|
||||
- tokenizerService.CountChatCompletionTokens(prompt.ToString()) // Initial user prompt.
|
||||
- appSettings.MaxOutputTokens; // To ensure there is enough space for the answer.
|
||||
|
||||
foreach (var chunk in chunks)
|
||||
{
|
||||
var text = $"---{Environment.NewLine}{chunk}";
|
||||
|
||||
var tokenCount = tokenizerService.CountTokens(text);
|
||||
var tokenCount = tokenizerService.CountChatCompletionTokens(text);
|
||||
if (tokenCount > tokensAvailable)
|
||||
{
|
||||
// There isn't enough space to add the current chunk.
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.SemanticKernel.Text;
|
||||
using SqlDatabaseVectorSearch.Settings;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Services;
|
||||
|
||||
public class TextChunkerService(TokenizerService tokenizerService, IOptions<AppSettings> appSettingsOptions)
|
||||
{
|
||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||
|
||||
public IList<string> Split(string text)
|
||||
{
|
||||
var lines = TextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountChatCompletionTokens);
|
||||
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountChatCompletionTokens);
|
||||
|
||||
return paragraphs;
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,13 @@ namespace SqlDatabaseVectorSearch.Services;
|
||||
|
||||
public class TokenizerService(IOptions<AzureOpenAISettings> settingsOptions)
|
||||
{
|
||||
private readonly TiktokenTokenizer tokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId);
|
||||
private readonly TiktokenTokenizer chatCompletiontokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId);
|
||||
|
||||
public int CountTokens(string input)
|
||||
=> tokenizer.CountTokens(input);
|
||||
private readonly TiktokenTokenizer embeddingTokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.Embedding.ModelId);
|
||||
|
||||
public int CountChatCompletionTokens(string input)
|
||||
=> chatCompletiontokenizer.CountTokens(input);
|
||||
|
||||
public int CountEmbeddingTokens(string input)
|
||||
=> embeddingTokenizer.CountTokens(input);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.SemanticKernel.Embeddings;
|
||||
using Microsoft.SemanticKernel.Text;
|
||||
using SqlDatabaseVectorSearch.ContentDecoders;
|
||||
using SqlDatabaseVectorSearch.DataAccessLayer;
|
||||
using SqlDatabaseVectorSearch.Models;
|
||||
@@ -11,7 +10,7 @@ using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities;
|
||||
|
||||
namespace SqlDatabaseVectorSearch.Services;
|
||||
|
||||
public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TokenizerService tokenizerService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
|
||||
public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, TokenizerService tokenizerService, TextChunkerService textChunkerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
|
||||
{
|
||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||
|
||||
@@ -22,7 +21,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
var content = await decoder.DecodeAsync(stream, contentType);
|
||||
|
||||
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
||||
var tokenCount = tokenizerService.CountTokens(content);
|
||||
var tokenCount = tokenizerService.CountEmbeddingTokens(content);
|
||||
|
||||
await dbContext.Database.BeginTransactionAsync();
|
||||
|
||||
@@ -36,14 +35,13 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
dbContext.Documents.Add(document);
|
||||
|
||||
// Split the content into chunks and generate the embeddings for each one.
|
||||
var lines = TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine, tokenizerService.CountTokens);
|
||||
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountTokens);
|
||||
var paragraphs = textChunkerService.Split(content);
|
||||
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs);
|
||||
|
||||
// Save the document chunks and the corresponding embedding in the database.
|
||||
foreach (var (index, paragraph) in paragraphs.Index())
|
||||
{
|
||||
logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountTokens(paragraph));
|
||||
logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountChatCompletionTokens(paragraph));
|
||||
|
||||
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() };
|
||||
dbContext.DocumentChunks.Add(documentChunk);
|
||||
@@ -92,7 +90,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
|
||||
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
||||
|
||||
return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
||||
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
|
||||
@@ -103,26 +101,32 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
||||
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
||||
|
||||
// The first message contains the question and the corresponding token usage (if reformulated).
|
||||
yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
|
||||
yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
|
||||
|
||||
TokenUsageResponse? tokenUsageResponse = null;
|
||||
|
||||
// Return each token as a partial response.
|
||||
await foreach (var (token, tokenUsage) in answerStream)
|
||||
{
|
||||
// Token usage is contained in the last message.
|
||||
yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null);
|
||||
// Token usage is expected in the last message.
|
||||
tokenUsageResponse = tokenUsage is not null ? new(tokenUsage) : null;
|
||||
yield return new(token, tokenUsageResponse is null ? StreamState.Append : StreamState.End, tokenUsageResponse);
|
||||
}
|
||||
|
||||
// The last message tells the client that the stream has ended.
|
||||
yield return new(null, null, StreamState.End);
|
||||
// If the token usage has not been returned in the last message, we must explicitly tells that the stream is ended.
|
||||
if (tokenUsageResponse is null)
|
||||
{
|
||||
yield return new(null, StreamState.End);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
|
||||
{
|
||||
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
|
||||
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
|
||||
var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!);
|
||||
|
||||
// Perform Vector Search on SQL Database.
|
||||
var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!);
|
||||
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
|
||||
|
||||
var chunks = await dbContext.DocumentChunks
|
||||
|
||||
Reference in New Issue
Block a user