From 8e069799937b35a60e4c8ec6dcf04ebe28a3de67 Mon Sep 17 00:00:00 2001 From: Marco Minerva Date: Thu, 30 Jan 2025 12:56:33 +0100 Subject: [PATCH 1/4] Refactor response types and enhance token usage handling --- .../Models/ChatResponse.cs | 3 ++ .../Models/ImportDocumentResponse.cs | 3 ++ .../Models/QuestionResponse.cs | 4 ++ SqlDatabaseVectorSearch/Models/Response.cs | 4 -- SqlDatabaseVectorSearch/Models/TokenUsage.cs | 6 +++ .../Models/TokenUsageResponse.cs | 3 ++ .../Models/UploadDocumentResponse.cs | 3 -- SqlDatabaseVectorSearch/Program.cs | 16 +++---- .../Services/ChatService.cs | 47 ++++++++++++++++--- .../Services/VectorSearchService.cs | 46 ++++++++++-------- 10 files changed, 95 insertions(+), 40 deletions(-) create mode 100644 SqlDatabaseVectorSearch/Models/ChatResponse.cs create mode 100644 SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs create mode 100644 SqlDatabaseVectorSearch/Models/QuestionResponse.cs delete mode 100644 SqlDatabaseVectorSearch/Models/Response.cs create mode 100644 SqlDatabaseVectorSearch/Models/TokenUsage.cs create mode 100644 SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs delete mode 100644 SqlDatabaseVectorSearch/Models/UploadDocumentResponse.cs diff --git a/SqlDatabaseVectorSearch/Models/ChatResponse.cs b/SqlDatabaseVectorSearch/Models/ChatResponse.cs new file mode 100644 index 0000000..1f67ca9 --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/ChatResponse.cs @@ -0,0 +1,3 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class ChatResponse(string? Text, TokenUsage? TokenUsage = null); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs new file mode 100644 index 0000000..a94991b --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs @@ -0,0 +1,3 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class ImportDocumentResponse(Guid DocumentId, int TotalTokenCount); diff --git a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs new file mode 100644 index 0000000..8cb72b2 --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs @@ -0,0 +1,4 @@ +namespace SqlDatabaseVectorSearch.Models; + +// Question and Asnwer can be null when using response streaming. +public record class QuestionResponse(string? Question, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/Response.cs b/SqlDatabaseVectorSearch/Models/Response.cs deleted file mode 100644 index faae25d..0000000 --- a/SqlDatabaseVectorSearch/Models/Response.cs +++ /dev/null @@ -1,4 +0,0 @@ -namespace SqlDatabaseVectorSearch.Models; - -// Question and Asnwer can be null when using response streaming. -public record class Response(string? Question, string? Answer, StreamState? StreamState = null); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/TokenUsage.cs b/SqlDatabaseVectorSearch/Models/TokenUsage.cs new file mode 100644 index 0000000..603ee1f --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/TokenUsage.cs @@ -0,0 +1,6 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class TokenUsage(int InputTokenCount, int OutputTokenCount) +{ + public int TotalTokenCount => InputTokenCount + OutputTokenCount; +} diff --git a/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs new file mode 100644 index 0000000..132ea4a --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs @@ -0,0 +1,3 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question); diff --git a/SqlDatabaseVectorSearch/Models/UploadDocumentResponse.cs b/SqlDatabaseVectorSearch/Models/UploadDocumentResponse.cs deleted file mode 100644 index 17d65fd..0000000 --- a/SqlDatabaseVectorSearch/Models/UploadDocumentResponse.cs +++ /dev/null @@ -1,3 +0,0 @@ -namespace SqlDatabaseVectorSearch.Models; - -public record class UploadDocumentResponse(Guid DocumentId); diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 4176fca..7c955a2 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -19,6 +19,11 @@ builder.Configuration.AddJsonFile("appsettings.local.json", optional: true, relo var aiSettings = builder.Services.ConfigureAndGet(builder.Configuration, "AzureOpenAI")!; var appSettings = builder.Services.ConfigureAndGet(builder.Configuration, nameof(AppSettings))!; +builder.Services.ConfigureHttpJsonOptions(options => +{ + options.SerializerOptions.Converters.Add(new JsonStringEnumConverter()); +}); + builder.Services.AddSingleton(TimeProvider.System); builder.Services.AddSqlServer(builder.Configuration.GetConnectionString("SqlConnection"), options => @@ -56,11 +61,6 @@ builder.Services.AddKeyedSingleton(MediaType builder.Services.AddKeyedSingleton("application/vnd.openxmlformats-officedocument.wordprocessingml.document"); builder.Services.AddKeyedSingleton(MediaTypeNames.Text.Plain); -builder.Services.ConfigureHttpJsonOptions(options => -{ - options.SerializerOptions.Converters.Add(new JsonStringEnumConverter()); -}); - builder.Services.AddOpenApi(options => { options.RemoveServerList(); @@ -127,9 +127,9 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => { using var stream = file.OpenReadStream(); - documentId = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId); + var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId); - return TypedResults.Ok(new UploadDocumentResponse(documentId.Value)); + return TypedResults.Ok(response); }) .DisableAntiforgery() .ProducesProblem(StatusCodes.Status400BadRequest) @@ -157,7 +157,7 @@ app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSear app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => { - async IAsyncEnumerable Stream() + async IAsyncEnumerable Stream() { // Requests a streaming response. var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); diff --git a/SqlDatabaseVectorSearch/Services/ChatService.cs b/SqlDatabaseVectorSearch/Services/ChatService.cs index d75db78..5854eec 100644 --- a/SqlDatabaseVectorSearch/Services/ChatService.cs +++ b/SqlDatabaseVectorSearch/Services/ChatService.cs @@ -3,6 +3,8 @@ using Microsoft.Extensions.Caching.Hybrid; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using OpenAI.Chat; +using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Settings; namespace SqlDatabaseVectorSearch.Services; @@ -11,7 +13,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer { private readonly AppSettings appSettings = appSettingsOptions.Value; - public async Task CreateQuestionAsync(Guid conversationId, string question) + public async Task CreateQuestionAsync(Guid conversationId, string question) { var chat = await GetChatHistoryAsync(conversationId); @@ -31,10 +33,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer await UpdateCacheAsync(conversationId, chat); - return reformulatedQuestion.Content!; + var tokenUsage = GetTokenUsage(reformulatedQuestion); + + return new(reformulatedQuestion.Content!, tokenUsage); } - public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question) + public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question) { var chat = CreateChatAsync(chunks, question); @@ -46,10 +50,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer // Add question and answer to the chat history. await SetChatHistoryAsync(conversationId, question, answer.Content!); - return answer.Content!; + var tokenUsage = GetTokenUsage(answer); + + return new(answer.Content!, tokenUsage); } - public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question) + public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question) { var chat = CreateChatAsync(chunks, question); @@ -61,15 +67,44 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer { if (!string.IsNullOrEmpty(token.Content)) { - yield return token.Content; + yield return new(token.Content); answer.Append(token.Content); } + else if (token.Content is null) + { + // Token usage is returned in the last message, when the Content is null. + var tokenUsage = GetTokenUsage(token); + if (tokenUsage is not null) + { + yield return new(null, tokenUsage); + } + } } // Add question and answer to the chat history. await SetChatHistoryAsync(conversationId, question, answer.ToString()); } + private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message) + { + if (message.InnerContent is ChatCompletion content && content.Usage is not null) + { + return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount); + } + + return null; + } + + private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.StreamingChatMessageContent message) + { + if (message.InnerContent is StreamingChatCompletionUpdate content && content.Usage is not null) + { + return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount); + } + + return null; + } + private ChatHistory CreateChatAsync(IEnumerable chunks, string question) { var chat = new ChatHistory(""" diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 3e0eca1..ad998e6 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -15,12 +15,15 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb { private readonly AppSettings appSettings = appSettingsOptions.Value; - public async Task ImportAsync(Stream stream, string name, string contentType, Guid? documentId) + public async Task ImportAsync(Stream stream, string name, string contentType, Guid? documentId) { // Extract the contents of the file. var decoder = serviceProvider.GetKeyedService(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported."); var content = await decoder.DecodeAsync(stream, contentType); + // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). + var tokenCount = tokenizerService.CountTokens(content); + await dbContext.Database.BeginTransactionAsync(); if (documentId.HasValue) @@ -49,7 +52,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb await dbContext.SaveChangesAsync(); await dbContext.Database.CommitTransactionAsync(); - return document.Id; + return new(document.Id, tokenCount); } public async Task> GetDocumentsAsync() @@ -82,40 +85,45 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb public Task DeleteDocumentAsync(Guid documentId) => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); - public async Task AskQuestionAsync(Question question, bool reformulate = true) + public async Task AskQuestionAsync(Question question, bool reformulate = true) { - var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); + // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. + var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); - var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); - return new Response(reformulatedQuestion, answer); + var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); + + return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); } - public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) + public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) { - var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); + // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. + var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); - var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion); + var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); - // The first message contains the original question. - yield return new Response(reformulatedQuestion, null, StreamState.Start); + // The first message contains the question and the corresponding token usage (if reformulated). + yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); // Return each token as a partial response. - await foreach (var token in answerStream) + await foreach (var (token, tokenUsage) in answerStream) { - yield return new Response(null, token, StreamState.Append); + // Token usage is contained in the last message. + yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null); } // The last message tells the client that the stream has ended. - yield return new Response(null, null, StreamState.End); + yield return new(null, null, StreamState.End); } - private async Task<(string Question, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) + private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) { - // Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings: - var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text; + // Reformulate the question taking into account the context of the chat to perform keyword search and embeddings. + var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text); // Perform Vector Search on SQL Database. - var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); + var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!); + var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!); var chunks = await dbContext.DocumentChunks .OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) @@ -123,6 +131,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb .Take(appSettings.MaxRelevantChunks) .ToListAsync(); - return (reformulatedQuestion, chunks); + return (reformulatedQuestion, embeddingTokenCount, chunks); } } \ No newline at end of file From cdd0199e8f6a590330a5ec5548e3608c07d72d5b Mon Sep 17 00:00:00 2001 From: Marco Minerva Date: Fri, 7 Feb 2025 10:24:16 +0100 Subject: [PATCH 2/4] Refactor services and update token handling - Replace `TotalTokenCount` with `EmbeddingTokenCount` in `ImportDocumentResponse`. - Add `OriginalQuestion` and `ReformulatedQuestion` fields to `QuestionResponse` and a new constructor. - Add a new constructor to `TokenUsageResponse` to initialize `Question`. - Add `TextChunkerService` to service collection in `Program.cs`. - Clarify prompt and update token counting in `ChatService`. - Differentiate token counting in `TokenizerService` with `CountChatCompletionTokens` and `CountEmbeddingTokens`. - Update `VectorSearchService` to use `TextChunkerService` and new token counting methods. - Introduce `TextChunkerService` for text splitting and tokenization. --- .../Models/ImportDocumentResponse.cs | 2 +- .../Models/QuestionResponse.cs | 10 +++++-- .../Models/TokenUsageResponse.cs | 8 ++++- SqlDatabaseVectorSearch/Program.cs | 1 + .../Services/ChatService.cs | 10 +++---- .../Services/TextChunkerService.cs | 18 +++++++++++ .../Services/TokenizerService.cs | 11 +++++-- .../Services/VectorSearchService.cs | 30 +++++++++++-------- 8 files changed, 65 insertions(+), 25 deletions(-) create mode 100644 SqlDatabaseVectorSearch/Services/TextChunkerService.cs diff --git a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs index a94991b..252018a 100644 --- a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs +++ b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs @@ -1,3 +1,3 @@ namespace SqlDatabaseVectorSearch.Models; -public record class ImportDocumentResponse(Guid DocumentId, int TotalTokenCount); +public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount); diff --git a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs index 8cb72b2..958f9d1 100644 --- a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs +++ b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs @@ -1,4 +1,10 @@ namespace SqlDatabaseVectorSearch.Models; -// Question and Asnwer can be null when using response streaming. -public record class QuestionResponse(string? Question, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null); \ No newline at end of file +// Question and Answer can be null when using response streaming. +public record class QuestionResponse(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null) +{ + public QuestionResponse(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null) + : this(null, null, token, streamState, tokenUsageResponse) + { + } +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs index 132ea4a..500a72e 100644 --- a/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs +++ b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs @@ -1,3 +1,9 @@ namespace SqlDatabaseVectorSearch.Models; -public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question); +public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question) +{ + public TokenUsageResponse(TokenUsage? question) + : this(null, null, question) + { + } +} diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 7c955a2..4d9e034 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -53,6 +53,7 @@ builder.Services.AddKernel() .AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions) .AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddScoped(); diff --git a/SqlDatabaseVectorSearch/Services/ChatService.cs b/SqlDatabaseVectorSearch/Services/ChatService.cs index 5854eec..f2e6ede 100644 --- a/SqlDatabaseVectorSearch/Services/ChatService.cs +++ b/SqlDatabaseVectorSearch/Services/ChatService.cs @@ -22,7 +22,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer --- {question} --- - You must reformulate the question in the same language of the user's question. + You must reformulate the question in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English. Never add "in this chat", "in the context of this chat", "in the context of our conversation", "search for" or something like that in your answer. """; @@ -111,7 +111,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer You can use only the information provided in this chat to answer questions. If you don't know the answer, reply suggesting to refine the question. For example, if the user asks "What is the capital of France?" and in this chat there isn't information about France, you should reply something like "This information isn't available in the given context". Never answer to questions that are not related to this chat. - You must answer in the same language of the user's question. + You must answer in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English. """); var prompt = new StringBuilder($""" @@ -124,15 +124,15 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer """); var tokensAvailable = appSettings.MaxInputTokens - - tokenizerService.CountTokens(chat[0].ToString()) // System prompt. - - tokenizerService.CountTokens(prompt.ToString()) // Initial user prompt. + - tokenizerService.CountChatCompletionTokens(chat[0].ToString()) // System prompt. + - tokenizerService.CountChatCompletionTokens(prompt.ToString()) // Initial user prompt. - appSettings.MaxOutputTokens; // To ensure there is enough space for the answer. foreach (var chunk in chunks) { var text = $"---{Environment.NewLine}{chunk}"; - var tokenCount = tokenizerService.CountTokens(text); + var tokenCount = tokenizerService.CountChatCompletionTokens(text); if (tokenCount > tokensAvailable) { // There isn't enough space to add the current chunk. diff --git a/SqlDatabaseVectorSearch/Services/TextChunkerService.cs b/SqlDatabaseVectorSearch/Services/TextChunkerService.cs new file mode 100644 index 0000000..851f08d --- /dev/null +++ b/SqlDatabaseVectorSearch/Services/TextChunkerService.cs @@ -0,0 +1,18 @@ +using Microsoft.Extensions.Options; +using Microsoft.SemanticKernel.Text; +using SqlDatabaseVectorSearch.Settings; + +namespace SqlDatabaseVectorSearch.Services; + +public class TextChunkerService(TokenizerService tokenizerService, IOptions appSettingsOptions) +{ + private readonly AppSettings appSettings = appSettingsOptions.Value; + + public IList Split(string text) + { + var lines = TextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountChatCompletionTokens); + var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountChatCompletionTokens); + + return paragraphs; + } +} diff --git a/SqlDatabaseVectorSearch/Services/TokenizerService.cs b/SqlDatabaseVectorSearch/Services/TokenizerService.cs index 9115b3d..71a3714 100644 --- a/SqlDatabaseVectorSearch/Services/TokenizerService.cs +++ b/SqlDatabaseVectorSearch/Services/TokenizerService.cs @@ -6,8 +6,13 @@ namespace SqlDatabaseVectorSearch.Services; public class TokenizerService(IOptions settingsOptions) { - private readonly TiktokenTokenizer tokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId); + private readonly TiktokenTokenizer chatCompletiontokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId); - public int CountTokens(string input) - => tokenizer.CountTokens(input); + private readonly TiktokenTokenizer embeddingTokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.Embedding.ModelId); + + public int CountChatCompletionTokens(string input) + => chatCompletiontokenizer.CountTokens(input); + + public int CountEmbeddingTokens(string input) + => embeddingTokenizer.CountTokens(input); } diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index ad998e6..2517156 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -2,7 +2,6 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.Embeddings; -using Microsoft.SemanticKernel.Text; using SqlDatabaseVectorSearch.ContentDecoders; using SqlDatabaseVectorSearch.DataAccessLayer; using SqlDatabaseVectorSearch.Models; @@ -11,7 +10,7 @@ using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities; namespace SqlDatabaseVectorSearch.Services; -public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TokenizerService tokenizerService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) +public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, TokenizerService tokenizerService, TextChunkerService textChunkerService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) { private readonly AppSettings appSettings = appSettingsOptions.Value; @@ -22,7 +21,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb var content = await decoder.DecodeAsync(stream, contentType); // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). - var tokenCount = tokenizerService.CountTokens(content); + var tokenCount = tokenizerService.CountEmbeddingTokens(content); await dbContext.Database.BeginTransactionAsync(); @@ -36,14 +35,13 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb dbContext.Documents.Add(document); // Split the content into chunks and generate the embeddings for each one. - var lines = TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine, tokenizerService.CountTokens); - var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountTokens); + var paragraphs = textChunkerService.Split(content); var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); // Save the document chunks and the corresponding embedding in the database. foreach (var (index, paragraph) in paragraphs.Index()) { - logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountTokens(paragraph)); + logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountChatCompletionTokens(paragraph)); var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() }; dbContext.DocumentChunks.Add(documentChunk); @@ -92,7 +90,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); - return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); + return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); } public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) @@ -103,26 +101,32 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); // The first message contains the question and the corresponding token usage (if reformulated). - yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); + yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); + + TokenUsageResponse? tokenUsageResponse = null; // Return each token as a partial response. await foreach (var (token, tokenUsage) in answerStream) { - // Token usage is contained in the last message. - yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null); + // Token usage is expected in the last message. + tokenUsageResponse = tokenUsage is not null ? new(tokenUsage) : null; + yield return new(token, tokenUsageResponse is null ? StreamState.Append : StreamState.End, tokenUsageResponse); } - // The last message tells the client that the stream has ended. - yield return new(null, null, StreamState.End); + // If the token usage has not been returned in the last message, we must explicitly tells that the stream is ended. + if (tokenUsageResponse is null) + { + yield return new(null, StreamState.End); + } } private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) { // Reformulate the question taking into account the context of the chat to perform keyword search and embeddings. var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text); + var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!); // Perform Vector Search on SQL Database. - var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!); var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!); var chunks = await dbContext.DocumentChunks From f02a1c9b69023f698bb64f8b53d3ffc9ee1147d0 Mon Sep 17 00:00:00 2001 From: Marco Minerva Date: Fri, 7 Feb 2025 11:30:14 +0100 Subject: [PATCH 3/4] Refactor document operations into DocumentService Refactored Program.cs to use AddAzureSql with new options. Added VectorSearchService and DocumentService as scoped services. Updated documentsApiGroup to use DocumentService for document operations and added a delete document endpoint. Moved document-related methods from VectorSearchService to new DocumentService for better separation of concerns. --- SqlDatabaseVectorSearch/Program.cs | 32 ++++---- .../Services/DocumentService.cs | 39 ++++++++++ .../Services/VectorSearchService.cs | 78 +++++++------------ 3 files changed, 83 insertions(+), 66 deletions(-) create mode 100644 SqlDatabaseVectorSearch/Services/DocumentService.cs diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 4d9e034..33bde71 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -26,7 +26,7 @@ builder.Services.ConfigureHttpJsonOptions(options => builder.Services.AddSingleton(TimeProvider.System); -builder.Services.AddSqlServer(builder.Configuration.GetConnectionString("SqlConnection"), options => +builder.Services.AddAzureSql(builder.Configuration.GetConnectionString("SqlConnection"), options => { options.UseVectorSearch(); }, options => @@ -56,7 +56,9 @@ builder.Services.AddKernel() builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); + builder.Services.AddScoped(); +builder.Services.AddScoped(); builder.Services.AddKeyedSingleton(MediaTypeNames.Application.Pdf); builder.Services.AddKeyedSingleton("application/vnd.openxmlformats-officedocument.wordprocessingml.document"); @@ -96,24 +98,24 @@ app.UseSwaggerUI(options => var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents"); -documentsApiGroup.MapGet(string.Empty, async (VectorSearchService vectorSearchService) => +documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService) => { - var documents = await vectorSearchService.GetDocumentsAsync(); + var documents = await documentService.GetDocumentsAsync(); return TypedResults.Ok(documents); }) .WithSummary("Gets the list of documents"); -documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, VectorSearchService vectorSearchService) => +documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService) => { - var documents = await vectorSearchService.GetDocumentChunksAsync(documentId); + var documents = await documentService.GetDocumentChunksAsync(documentId); return TypedResults.Ok(documents); }) .WithSummary("Gets the list of chunks of a given document") .WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk."); -documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task, NotFound>> (Guid documentId, Guid documentChunkId, VectorSearchService vectorSearchService) => +documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService) => { - var chunk = await vectorSearchService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId); + var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId); if (chunk is null) { return TypedResults.NotFound(); @@ -124,6 +126,14 @@ documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", asyn .ProducesProblem(StatusCodes.Status404NotFound) .WithSummary("Gets the details of a given chunk, includings its embedding"); +documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService) => +{ + await documentService.DeleteDocumentAsync(documentId); + return TypedResults.NoContent(); +}) +.WithSummary("Deletes a document") +.WithDescription("This endpoint deletes the document and all its chunks."); + documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => { @@ -137,14 +147,6 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi .WithSummary("Uploads a document") .WithDescription("Uploads a document to SQL Database and saves its embedding using the native VECTOR type. The document will be indexed and used to answer questions. Currently, PDF, DOCX and TXT files are supported."); -documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, VectorSearchService vectorSearchService) => -{ - await vectorSearchService.DeleteDocumentAsync(documentId); - return TypedResults.NoContent(); -}) -.WithSummary("Deletes a document") -.WithDescription("This endpoint deletes the document and all its chunks."); - app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => { diff --git a/SqlDatabaseVectorSearch/Services/DocumentService.cs b/SqlDatabaseVectorSearch/Services/DocumentService.cs new file mode 100644 index 0000000..b6a7ed5 --- /dev/null +++ b/SqlDatabaseVectorSearch/Services/DocumentService.cs @@ -0,0 +1,39 @@ +using System.Data; +using Microsoft.EntityFrameworkCore; +using SqlDatabaseVectorSearch.DataAccessLayer; +using SqlDatabaseVectorSearch.Models; + +namespace SqlDatabaseVectorSearch.Services; + +public class DocumentService(ApplicationDbContext dbContext) +{ + public async Task> GetDocumentsAsync() + { + var documents = await dbContext.Documents.OrderBy(d => d.Name) + .Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count)) + .ToListAsync(); + + return documents; + } + + public async Task> GetDocumentChunksAsync(Guid documentId) + { + var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index) + .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null)) + .ToListAsync(); + + return documentChunks; + } + + public async Task GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) + { + var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId) + .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding)) + .FirstOrDefaultAsync(); + + return documentChunk; + } + + public Task DeleteDocumentAsync(Guid documentId) + => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 2517156..74aa741 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -10,7 +10,7 @@ using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities; namespace SqlDatabaseVectorSearch.Services; -public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, TokenizerService tokenizerService, TextChunkerService textChunkerService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) +public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, DocumentService documentService, ITextEmbeddingGenerationService textEmbeddingGenerationService, TokenizerService tokenizerService, TextChunkerService textChunkerService, ChatService chatService, TimeProvider timeProvider, IOptions appSettingsOptions, ILogger logger) { private readonly AppSettings appSettings = appSettingsOptions.Value; @@ -23,66 +23,42 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). var tokenCount = tokenizerService.CountEmbeddingTokens(content); - await dbContext.Database.BeginTransactionAsync(); - - if (documentId.HasValue) + var strategy = dbContext.Database.CreateExecutionStrategy(); + var document = await strategy.ExecuteAsync(async () => { - // If the user is importing a document that already exists, delete the previous one. - await DeleteDocumentAsync(documentId.Value); - } + await dbContext.Database.BeginTransactionAsync(); - var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; - dbContext.Documents.Add(document); + if (documentId.HasValue) + { + // If the user is importing a document that already exists, delete the previous one. + await documentService.DeleteDocumentAsync(documentId.Value); + } - // Split the content into chunks and generate the embeddings for each one. - var paragraphs = textChunkerService.Split(content); - var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); + var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; + dbContext.Documents.Add(document); - // Save the document chunks and the corresponding embedding in the database. - foreach (var (index, paragraph) in paragraphs.Index()) - { - logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountChatCompletionTokens(paragraph)); + // Split the content into chunks and generate the embeddings for each one. + var paragraphs = textChunkerService.Split(content); + var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); - var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() }; - dbContext.DocumentChunks.Add(documentChunk); - } + // Save the document chunks and the corresponding embedding in the database. + foreach (var (index, paragraph) in paragraphs.Index()) + { + logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountChatCompletionTokens(paragraph)); - await dbContext.SaveChangesAsync(); - await dbContext.Database.CommitTransactionAsync(); + var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() }; + dbContext.DocumentChunks.Add(documentChunk); + } + + await dbContext.SaveChangesAsync(); + await dbContext.Database.CommitTransactionAsync(); + + return document; + }); return new(document.Id, tokenCount); } - public async Task> GetDocumentsAsync() - { - var documents = await dbContext.Documents.OrderBy(d => d.Name) - .Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count)) - .ToListAsync(); - - return documents; - } - - public async Task> GetDocumentChunksAsync(Guid documentId) - { - var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index) - .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null)) - .ToListAsync(); - - return documentChunks; - } - - public async Task GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) - { - var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId) - .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding)) - .FirstOrDefaultAsync(); - - return documentChunk; - } - - public Task DeleteDocumentAsync(Guid documentId) - => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); - public async Task AskQuestionAsync(Question question, bool reformulate = true) { // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. From 9312bf35cb448e6b5899cdec5fa315f589cb6e1e Mon Sep 17 00:00:00 2001 From: Marco Minerva Date: Fri, 7 Feb 2025 12:02:48 +0100 Subject: [PATCH 4/4] Add new endpoints and update README with details Updated README.md to document new `/api/documents`, `/api/ask`, and `/api/ask-streaming` endpoints, including features like conversation history, token usage, and response streaming. Enhanced Program.cs by adding new endpoints for asking questions and streaming responses, with additional metadata. Updated `documentsApiGroup` to include new document management endpoints. --- README.md | 181 ++++++++++++++++++++++++++++- SqlDatabaseVectorSearch/Program.cs | 60 +++++----- 2 files changed, 209 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index ee762b6..f953290 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,188 @@ The application is a Minimal API that exposes endpoints to load documents, gener ![SQL Database Vector Search](https://github.com/marcominerva/SqlDatabaseVectorSearch/blob/master/SqlDatabaseVectorSearch.png) -### Setup +## Setup - [Create an Azure SQL Database](https://learn.microsoft.com/en-us/azure/azure-sql/database/single-database-create-quickstart) on a server that has the Vector Support feature enabled - Execute the [Scripts.sql](https://github.com/marcominerva/SqlDatabaseVectorSearch/blob/master/Scripts.sql) file to create the tables needed by the application - You may need to update the size of the [`VECTOR`](https://github.com/marcominerva/SqlDatabaseVectorSearch/blob/master/Scripts.sql#L17) column to match the size of the embedding model. Currently, the maximum allowed value is 1998. - Open the [appsettings.json](https://github.com/marcominerva/SqlDatabaseVectorSearch/blob/master/SqlDatabaseVectorSearch/appsettings.json) file and set the connection string to the database and the other settings required by Azure OpenAI - If your embedding model supports shortening, like **text-embedding-3-small** and **text-embedding-3-large**, and you want to use this feature, you need to set the [`Dimensions`](https://github.com/marcominerva/SqlDatabaseVectorSearch/blob/master/SqlDatabaseVectorSearch/appsettings.json#L17) property to match the value you have used in the SQL script. If your model doesn't provide this feature, or do you want to use the default size, just leave the [`Dimensions`](https://github.com/marcominerva/SqlDatabaseVectorSearch/blob/master/SqlDatabaseVectorSearch/appsettings.json#L17) property to NULL. Keep in mind that **text-embedding-3-small** has a dimension of 1536, while **text-embedding-3-large** uses vectors with 3072 elements, so with this latter model it is mandatory to specify a value (that, as said, must be less or equal to 1998). -- Run the application and start importing your documents. +- Run the application and start importing your documents with `/api/documents` endpoint. +- Ask questions using `/api/ask` or `/api/ask-streaming` endpoints. + +## Supported features + +- Conversation history with question reformulation +- Information about token usage +- Response streaming + +```json +{ + "originalQuestion": "why is mars called the red planet?", + "reformulatedQuestion": "Why is Mars referred to as the Red Planet?", + "answer": "Mars is referred to as the Red Planet due to its characteristic reddish color, which is caused by the abundance of iron oxide (rust) on its surface. This distinctive coloration has also been a significant factor in the cultural and mythological associations of Mars across different civilizations.", + "streamState": null, + "tokenUsage": { + "reformulation": { + "inputTokenCount": 107, + "outputTokenCount": 10, + "totalTokenCount": 117 + }, + "embeddingTokenCount": 10, + "question": { + "inputTokenCount": 9142, + "outputTokenCount": 53, + "totalTokenCount": 9195 + } + } +} +``` + +### How response streaming works + +When using the `/api/ask-streaming` endpoint, answers will be streamed as happens with the typical response from OpenAI. The format of the response is the following: + +```json +[ + { + "originalQuestion": "why is mars called the red planet?", + "reformulatedQuestion": "Why is Mars referred to as the Red Planet?", + "answer": null, + "streamState": "Start", + "tokenUsage": { + "reformulation": { + "inputTokenCount": 107, + "outputTokenCount": 10, + "totalTokenCount": 117 + }, + "embeddingTokenCount": 10, + "question": null + } + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": "Mars", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " is", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " called", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " the", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " Red", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " Planet", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " because", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " of", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " its", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " characteristic", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " red", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": " color", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": ",", + "streamState": "Append", + "tokenUsage": null + }, + //... + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": ".", + "streamState": "Append", + "tokenUsage": null + }, + { + "originalQuestion": null, + "reformulatedQuestion": null, + "answer": null, + "streamState": "End", + "tokenUsage": { + "reformulation": null, + "embeddingTokenCount": null, + "question": { + "inputTokenCount": 8986, + "outputTokenCount": 31, + "totalTokenCount": 9017 + } + } + } +] +``` + +- The first piece of the response has the following characteristics: + - the *streamState* tag is set to `Start`, + - it contains the question and its reformulation (if not requested, *reformulatedQuestion* will be equals to *originalQuestion*) + - the *tokenUsage* section holds information about token used for reformulation (if done) and for the embedding of the question +- Then, there are as many elements for the actual answer as necessary: + - each one contains a token + - The *streamState* property is set to `Append` + - *origianlQuestion*, *reformulatedQuestion* and *tokenUsage* are always `null` +- The stream ends when an element with *streamState* equals to `End` is received. This element contains token usage information for the question and the whole answer. \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 33bde71..8a6c94b 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -96,6 +96,36 @@ app.UseSwaggerUI(options => options.SwaggerEndpoint("/openapi/v1.json", builder.Environment.ApplicationName); }); +app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, + [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => +{ + var response = await vectorSearchService.AskQuestionAsync(question, reformulate); + return TypedResults.Ok(response); +}) +.WithSummary("Asks a question") +.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") +.WithTags("Ask"); + +app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, + [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => +{ + async IAsyncEnumerable Stream() + { + // Requests a streaming response. + var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); + + await foreach (var delta in responseStream) + { + yield return delta; + } + } + + return Stream(); +}) +.WithSummary("Asks a question and gets the response as streaming") +.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") +.WithTags("Ask"); + var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents"); documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService) => @@ -147,34 +177,4 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi .WithSummary("Uploads a document") .WithDescription("Uploads a document to SQL Database and saves its embedding using the native VECTOR type. The document will be indexed and used to answer questions. Currently, PDF, DOCX and TXT files are supported."); -app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, - [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => -{ - var response = await vectorSearchService.AskQuestionAsync(question, reformulate); - return TypedResults.Ok(response); -}) -.WithSummary("Asks a question") -.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") -.WithTags("Ask"); - -app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, - [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => -{ - async IAsyncEnumerable Stream() - { - // Requests a streaming response. - var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); - - await foreach (var delta in responseStream) - { - yield return delta; - } - } - - return Stream(); -}) -.WithSummary("Asks a question and gets the response as streaming") -.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") -.WithTags("Ask"); - app.Run(); \ No newline at end of file