From 44c6193674837a4b181cfd80985a8ca35d103c19 Mon Sep 17 00:00:00 2001 From: Marco Minerva Date: Tue, 28 Jan 2025 10:14:47 +0100 Subject: [PATCH] Add streaming and refactor chat/question handling Updated `Response` record in `Response.cs` to include an optional `StreamState` property, which can be `Start`, `Append`, or `End`. Added a new `StreamState` enum to `Response.cs`. In `ChatService.cs`, added new methods `AskQuestionAsync` and `AskStreamingAsync` to handle asking questions and streaming responses, respectively. Refactored `CreateChatAsync` to return a `ChatHistory` object. In `VectorSearchService.cs`, added a new `AskQuestionAsync` method to handle questions using `ChatService`. Updated `CreateContextAsync` to return a tuple with the reformulated question and chunks. Removed the previous implementation of `AskQuestionAsync` and replaced it with the new method utilizing `ChatService`. --- SqlDatabaseVectorSearch/Models/Response.cs | 9 +++- .../Services/ChatService.cs | 47 +++++++++++++++---- .../Services/VectorSearchService.cs | 11 ++++- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/SqlDatabaseVectorSearch/Models/Response.cs b/SqlDatabaseVectorSearch/Models/Response.cs index 427d94e..b2e5c5e 100644 --- a/SqlDatabaseVectorSearch/Models/Response.cs +++ b/SqlDatabaseVectorSearch/Models/Response.cs @@ -1,3 +1,10 @@ namespace SqlDatabaseVectorSearch.Models; -public record class Response(string Question, string Answer); +public record class Response(string Question, string Answer, StreamState? StreamState = null); + +public enum StreamState +{ + Start, + Append, + End +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Services/ChatService.cs b/SqlDatabaseVectorSearch/Services/ChatService.cs index 4a109dc..ba1bee6 100644 --- a/SqlDatabaseVectorSearch/Services/ChatService.cs +++ b/SqlDatabaseVectorSearch/Services/ChatService.cs @@ -35,6 +35,42 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer } public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question) + { + var chat = CreateChatAsync(chunks, question); + + var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings + { + MaxTokens = appSettings.MaxOutputTokens + }); + + // Add question and answer to the chat history. + await SetChatHistoryAsync(conversationId, question, answer.Content!); + + return answer.Content!; + } + + public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question) + { + var chat = CreateChatAsync(chunks, question); + + var answer = new StringBuilder(); + await foreach (var token in chatCompletionService.GetStreamingChatMessageContentsAsync(chat, new AzureOpenAIPromptExecutionSettings + { + MaxTokens = appSettings.MaxOutputTokens + })) + { + if (!string.IsNullOrEmpty(token.Content)) + { + yield return token.Content; + answer.Append(token.Content); + } + } + + // Add question and answer to the chat history. + await SetChatHistoryAsync(conversationId, question, answer.ToString()); + } + + private ChatHistory CreateChatAsync(IEnumerable chunks, string question) { var chat = new ChatHistory(""" You can use only the information provided in this chat to answer questions. If you don't know the answer, reply suggesting to refine the question. @@ -79,16 +115,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer } chat.AddUserMessage(prompt.ToString()); - - var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings - { - MaxTokens = appSettings.MaxOutputTokens - }); - - // Add question and answer to the chat history. - await SetChatHistoryAsync(conversationId, question, answer.Content!); - - return answer.Content!; + return chat; } private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat) diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 1417ce0..53ce117 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -81,6 +81,14 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); public async Task AskQuestionAsync(Question question, bool reformulate = true) + { + var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); + + var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); + return new Response(reformulatedQuestion, answer); + } + + private async Task<(string Question, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) { // Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings: var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text; @@ -94,8 +102,7 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG .Take(appSettings.MaxRelevantChunks) .ToListAsync(); - var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); - return new Response(reformulatedQuestion, answer); + return (reformulatedQuestion, chunks); } private static Task GetContentAsync(Stream stream)