diff --git a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs index a94991b..252018a 100644 --- a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs +++ b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs @@ -1,3 +1,3 @@ namespace SqlDatabaseVectorSearch.Models; -public record class ImportDocumentResponse(Guid DocumentId, int TotalTokenCount); +public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount); diff --git a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs index 8cb72b2..958f9d1 100644 --- a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs +++ b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs @@ -1,4 +1,10 @@ namespace SqlDatabaseVectorSearch.Models; -// Question and Asnwer can be null when using response streaming. -public record class QuestionResponse(string? Question, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null); \ No newline at end of file +// Question and Answer can be null when using response streaming. +public record class QuestionResponse(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null) +{ + public QuestionResponse(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null) + : this(null, null, token, streamState, tokenUsageResponse) + { + } +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs index 132ea4a..500a72e 100644 --- a/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs +++ b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs @@ -1,3 +1,9 @@ namespace SqlDatabaseVectorSearch.Models; -public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question); +public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question) +{ + public TokenUsageResponse(TokenUsage? question) + : this(null, null, question) + { + } +} diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 7c955a2..4d9e034 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -53,6 +53,7 @@ builder.Services.AddKernel() .AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions) .AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddScoped(); diff --git a/SqlDatabaseVectorSearch/Services/ChatService.cs b/SqlDatabaseVectorSearch/Services/ChatService.cs index 5854eec..f2e6ede 100644 --- a/SqlDatabaseVectorSearch/Services/ChatService.cs +++ b/SqlDatabaseVectorSearch/Services/ChatService.cs @@ -22,7 +22,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer --- {question} --- - You must reformulate the question in the same language of the user's question. + You must reformulate the question in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English. Never add "in this chat", "in the context of this chat", "in the context of our conversation", "search for" or something like that in your answer. """; @@ -111,7 +111,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer You can use only the information provided in this chat to answer questions. If you don't know the answer, reply suggesting to refine the question. For example, if the user asks "What is the capital of France?" and in this chat there isn't information about France, you should reply something like "This information isn't available in the given context". Never answer to questions that are not related to this chat. - You must answer in the same language of the user's question. + You must answer in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English. """); var prompt = new StringBuilder($""" @@ -124,15 +124,15 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer """); var tokensAvailable = appSettings.MaxInputTokens - - tokenizerService.CountTokens(chat[0].ToString()) // System prompt. - - tokenizerService.CountTokens(prompt.ToString()) // Initial user prompt. + - tokenizerService.CountChatCompletionTokens(chat[0].ToString()) // System prompt. + - tokenizerService.CountChatCompletionTokens(prompt.ToString()) // Initial user prompt. - appSettings.MaxOutputTokens; // To ensure there is enough space for the answer. foreach (var chunk in chunks) { var text = $"---{Environment.NewLine}{chunk}"; - var tokenCount = tokenizerService.CountTokens(text); + var tokenCount = tokenizerService.CountChatCompletionTokens(text); if (tokenCount > tokensAvailable) { // There isn't enough space to add the current chunk. diff --git a/SqlDatabaseVectorSearch/Services/TextChunkerService.cs b/SqlDatabaseVectorSearch/Services/TextChunkerService.cs new file mode 100644 index 0000000..851f08d --- /dev/null +++ b/SqlDatabaseVectorSearch/Services/TextChunkerService.cs @@ -0,0 +1,18 @@ +using Microsoft.Extensions.Options; +using Microsoft.SemanticKernel.Text; +using SqlDatabaseVectorSearch.Settings; + +namespace SqlDatabaseVectorSearch.Services; + +public class TextChunkerService(TokenizerService tokenizerService, IOptions appSettingsOptions) +{ + private readonly AppSettings appSettings = appSettingsOptions.Value; + + public IList Split(string text) + { + var lines = TextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountChatCompletionTokens); + var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountChatCompletionTokens); + + return paragraphs; + } +} diff --git a/SqlDatabaseVectorSearch/Services/TokenizerService.cs b/SqlDatabaseVectorSearch/Services/TokenizerService.cs index 9115b3d..71a3714 100644 --- a/SqlDatabaseVectorSearch/Services/TokenizerService.cs +++ b/SqlDatabaseVectorSearch/Services/TokenizerService.cs @@ -6,8 +6,13 @@ namespace SqlDatabaseVectorSearch.Services; public class TokenizerService(IOptions settingsOptions) { - private readonly TiktokenTokenizer tokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId); + private readonly TiktokenTokenizer chatCompletiontokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId); - public int CountTokens(string input) - => tokenizer.CountTokens(input); + private readonly TiktokenTokenizer embeddingTokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.Embedding.ModelId); + + public int CountChatCompletionTokens(string input) + => chatCompletiontokenizer.CountTokens(input); + + public int CountEmbeddingTokens(string input) + => embeddingTokenizer.CountTokens(input); } diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index ad998e6..2517156 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -2,7 +2,6 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.Embeddings; -using Microsoft.SemanticKernel.Text; using SqlDatabaseVectorSearch.ContentDecoders; using SqlDatabaseVectorSearch.DataAccessLayer; using SqlDatabaseVectorSearch.Models; @@ -11,7 +10,7 @@ using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities; namespace SqlDatabaseVectorSearch.Services; -public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TokenizerService tokenizerService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) +public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, TokenizerService tokenizerService, TextChunkerService textChunkerService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) { private readonly AppSettings appSettings = appSettingsOptions.Value; @@ -22,7 +21,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb var content = await decoder.DecodeAsync(stream, contentType); // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). - var tokenCount = tokenizerService.CountTokens(content); + var tokenCount = tokenizerService.CountEmbeddingTokens(content); await dbContext.Database.BeginTransactionAsync(); @@ -36,14 +35,13 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb dbContext.Documents.Add(document); // Split the content into chunks and generate the embeddings for each one. - var lines = TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine, tokenizerService.CountTokens); - var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountTokens); + var paragraphs = textChunkerService.Split(content); var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); // Save the document chunks and the corresponding embedding in the database. foreach (var (index, paragraph) in paragraphs.Index()) { - logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountTokens(paragraph)); + logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountChatCompletionTokens(paragraph)); var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() }; dbContext.DocumentChunks.Add(documentChunk); @@ -92,7 +90,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); - return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); + return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); } public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) @@ -103,26 +101,32 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); // The first message contains the question and the corresponding token usage (if reformulated). - yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); + yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); + + TokenUsageResponse? tokenUsageResponse = null; // Return each token as a partial response. await foreach (var (token, tokenUsage) in answerStream) { - // Token usage is contained in the last message. - yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null); + // Token usage is expected in the last message. + tokenUsageResponse = tokenUsage is not null ? new(tokenUsage) : null; + yield return new(token, tokenUsageResponse is null ? StreamState.Append : StreamState.End, tokenUsageResponse); } - // The last message tells the client that the stream has ended. - yield return new(null, null, StreamState.End); + // If the token usage has not been returned in the last message, we must explicitly tells that the stream is ended. + if (tokenUsageResponse is null) + { + yield return new(null, StreamState.End); + } } private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) { // Reformulate the question taking into account the context of the chat to perform keyword search and embeddings. var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text); + var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!); // Perform Vector Search on SQL Database. - var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!); var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!); var chunks = await dbContext.DocumentChunks