diff --git a/SqlDatabaseVectorSearch/Models/ChatResponse.cs b/SqlDatabaseVectorSearch/Models/ChatResponse.cs new file mode 100644 index 0000000..1f67ca9 --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/ChatResponse.cs @@ -0,0 +1,3 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class ChatResponse(string? Text, TokenUsage? TokenUsage = null); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs new file mode 100644 index 0000000..a94991b --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/ImportDocumentResponse.cs @@ -0,0 +1,3 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class ImportDocumentResponse(Guid DocumentId, int TotalTokenCount); diff --git a/SqlDatabaseVectorSearch/Models/QuestionResponse.cs b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs new file mode 100644 index 0000000..8cb72b2 --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/QuestionResponse.cs @@ -0,0 +1,4 @@ +namespace SqlDatabaseVectorSearch.Models; + +// Question and Asnwer can be null when using response streaming. +public record class QuestionResponse(string? Question, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/Response.cs b/SqlDatabaseVectorSearch/Models/Response.cs deleted file mode 100644 index faae25d..0000000 --- a/SqlDatabaseVectorSearch/Models/Response.cs +++ /dev/null @@ -1,4 +0,0 @@ -namespace SqlDatabaseVectorSearch.Models; - -// Question and Asnwer can be null when using response streaming. -public record class Response(string? Question, string? Answer, StreamState? StreamState = null); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/TokenUsage.cs b/SqlDatabaseVectorSearch/Models/TokenUsage.cs new file mode 100644 index 0000000..603ee1f --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/TokenUsage.cs @@ -0,0 +1,6 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class TokenUsage(int InputTokenCount, int OutputTokenCount) +{ + public int TotalTokenCount => InputTokenCount + OutputTokenCount; +} diff --git a/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs new file mode 100644 index 0000000..132ea4a --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/TokenUsageResponse.cs @@ -0,0 +1,3 @@ +namespace SqlDatabaseVectorSearch.Models; + +public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question); diff --git a/SqlDatabaseVectorSearch/Models/UploadDocumentResponse.cs b/SqlDatabaseVectorSearch/Models/UploadDocumentResponse.cs deleted file mode 100644 index 17d65fd..0000000 --- a/SqlDatabaseVectorSearch/Models/UploadDocumentResponse.cs +++ /dev/null @@ -1,3 +0,0 @@ -namespace SqlDatabaseVectorSearch.Models; - -public record class UploadDocumentResponse(Guid DocumentId); diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 4176fca..7c955a2 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -19,6 +19,11 @@ builder.Configuration.AddJsonFile("appsettings.local.json", optional: true, relo var aiSettings = builder.Services.ConfigureAndGet(builder.Configuration, "AzureOpenAI")!; var appSettings = builder.Services.ConfigureAndGet(builder.Configuration, nameof(AppSettings))!; +builder.Services.ConfigureHttpJsonOptions(options => +{ + options.SerializerOptions.Converters.Add(new JsonStringEnumConverter()); +}); + builder.Services.AddSingleton(TimeProvider.System); builder.Services.AddSqlServer(builder.Configuration.GetConnectionString("SqlConnection"), options => @@ -56,11 +61,6 @@ builder.Services.AddKeyedSingleton(MediaType builder.Services.AddKeyedSingleton("application/vnd.openxmlformats-officedocument.wordprocessingml.document"); builder.Services.AddKeyedSingleton(MediaTypeNames.Text.Plain); -builder.Services.ConfigureHttpJsonOptions(options => -{ - options.SerializerOptions.Converters.Add(new JsonStringEnumConverter()); -}); - builder.Services.AddOpenApi(options => { options.RemoveServerList(); @@ -127,9 +127,9 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => { using var stream = file.OpenReadStream(); - documentId = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId); + var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId); - return TypedResults.Ok(new UploadDocumentResponse(documentId.Value)); + return TypedResults.Ok(response); }) .DisableAntiforgery() .ProducesProblem(StatusCodes.Status400BadRequest) @@ -157,7 +157,7 @@ app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSear app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => { - async IAsyncEnumerable Stream() + async IAsyncEnumerable Stream() { // Requests a streaming response. var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); diff --git a/SqlDatabaseVectorSearch/Services/ChatService.cs b/SqlDatabaseVectorSearch/Services/ChatService.cs index d75db78..5854eec 100644 --- a/SqlDatabaseVectorSearch/Services/ChatService.cs +++ b/SqlDatabaseVectorSearch/Services/ChatService.cs @@ -3,6 +3,8 @@ using Microsoft.Extensions.Caching.Hybrid; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using OpenAI.Chat; +using SqlDatabaseVectorSearch.Models; using SqlDatabaseVectorSearch.Settings; namespace SqlDatabaseVectorSearch.Services; @@ -11,7 +13,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer { private readonly AppSettings appSettings = appSettingsOptions.Value; - public async Task CreateQuestionAsync(Guid conversationId, string question) + public async Task CreateQuestionAsync(Guid conversationId, string question) { var chat = await GetChatHistoryAsync(conversationId); @@ -31,10 +33,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer await UpdateCacheAsync(conversationId, chat); - return reformulatedQuestion.Content!; + var tokenUsage = GetTokenUsage(reformulatedQuestion); + + return new(reformulatedQuestion.Content!, tokenUsage); } - public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question) + public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question) { var chat = CreateChatAsync(chunks, question); @@ -46,10 +50,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer // Add question and answer to the chat history. await SetChatHistoryAsync(conversationId, question, answer.Content!); - return answer.Content!; + var tokenUsage = GetTokenUsage(answer); + + return new(answer.Content!, tokenUsage); } - public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question) + public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question) { var chat = CreateChatAsync(chunks, question); @@ -61,15 +67,44 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer { if (!string.IsNullOrEmpty(token.Content)) { - yield return token.Content; + yield return new(token.Content); answer.Append(token.Content); } + else if (token.Content is null) + { + // Token usage is returned in the last message, when the Content is null. + var tokenUsage = GetTokenUsage(token); + if (tokenUsage is not null) + { + yield return new(null, tokenUsage); + } + } } // Add question and answer to the chat history. await SetChatHistoryAsync(conversationId, question, answer.ToString()); } + private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message) + { + if (message.InnerContent is ChatCompletion content && content.Usage is not null) + { + return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount); + } + + return null; + } + + private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.StreamingChatMessageContent message) + { + if (message.InnerContent is StreamingChatCompletionUpdate content && content.Usage is not null) + { + return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount); + } + + return null; + } + private ChatHistory CreateChatAsync(IEnumerable chunks, string question) { var chat = new ChatHistory(""" diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 3e0eca1..ad998e6 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -15,12 +15,15 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb { private readonly AppSettings appSettings = appSettingsOptions.Value; - public async Task ImportAsync(Stream stream, string name, string contentType, Guid? documentId) + public async Task ImportAsync(Stream stream, string name, string contentType, Guid? documentId) { // Extract the contents of the file. var decoder = serviceProvider.GetKeyedService(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported."); var content = await decoder.DecodeAsync(stream, contentType); + // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). + var tokenCount = tokenizerService.CountTokens(content); + await dbContext.Database.BeginTransactionAsync(); if (documentId.HasValue) @@ -49,7 +52,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb await dbContext.SaveChangesAsync(); await dbContext.Database.CommitTransactionAsync(); - return document.Id; + return new(document.Id, tokenCount); } public async Task> GetDocumentsAsync() @@ -82,40 +85,45 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb public Task DeleteDocumentAsync(Guid documentId) => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); - public async Task AskQuestionAsync(Question question, bool reformulate = true) + public async Task AskQuestionAsync(Question question, bool reformulate = true) { - var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); + // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. + var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); - var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); - return new Response(reformulatedQuestion, answer); + var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); + + return new(reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); } - public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) + public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) { - var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); + // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. + var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); - var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion); + var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); - // The first message contains the original question. - yield return new Response(reformulatedQuestion, null, StreamState.Start); + // The first message contains the question and the corresponding token usage (if reformulated). + yield return new(reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); // Return each token as a partial response. - await foreach (var token in answerStream) + await foreach (var (token, tokenUsage) in answerStream) { - yield return new Response(null, token, StreamState.Append); + // Token usage is contained in the last message. + yield return new(null, token, StreamState.Append, tokenUsage is not null ? new(null, null, tokenUsage) : null); } // The last message tells the client that the stream has ended. - yield return new Response(null, null, StreamState.End); + yield return new(null, null, StreamState.End); } - private async Task<(string Question, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) + private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) { - // Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings: - var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text; + // Reformulate the question taking into account the context of the chat to perform keyword search and embeddings. + var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text); // Perform Vector Search on SQL Database. - var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion); + var embeddingTokenCount = tokenizerService.CountTokens(reformulatedQuestion.Text!); + var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!); var chunks = await dbContext.DocumentChunks .OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) @@ -123,6 +131,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb .Take(appSettings.MaxRelevantChunks) .ToListAsync(); - return (reformulatedQuestion, chunks); + return (reformulatedQuestion, embeddingTokenCount, chunks); } } \ No newline at end of file