diff --git a/SqlDatabaseVectorSearch.png b/SqlDatabaseVectorSearch.png index 86c8658..2bbbbed 100644 Binary files a/SqlDatabaseVectorSearch.png and b/SqlDatabaseVectorSearch.png differ diff --git a/SqlDatabaseVectorSearch/Models/Response.cs b/SqlDatabaseVectorSearch/Models/Response.cs index 427d94e..faae25d 100644 --- a/SqlDatabaseVectorSearch/Models/Response.cs +++ b/SqlDatabaseVectorSearch/Models/Response.cs @@ -1,3 +1,4 @@ namespace SqlDatabaseVectorSearch.Models; -public record class Response(string Question, string Answer); +// Question and Asnwer can be null when using response streaming. +public record class Response(string? Question, string? Answer, StreamState? StreamState = null); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Models/StreamState.cs b/SqlDatabaseVectorSearch/Models/StreamState.cs new file mode 100644 index 0000000..2bb25ad --- /dev/null +++ b/SqlDatabaseVectorSearch/Models/StreamState.cs @@ -0,0 +1,8 @@ +namespace SqlDatabaseVectorSearch.Models; + +public enum StreamState +{ + Start, + Append, + End +} \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 522a857..24807d8 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -1,4 +1,5 @@ using System.ComponentModel; +using System.Text.Json.Serialization; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.EntityFrameworkCore; using Microsoft.SemanticKernel; @@ -49,6 +50,11 @@ builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddScoped(); +builder.Services.ConfigureHttpJsonOptions(options => +{ + options.SerializerOptions.Converters.Add(new JsonStringEnumConverter()); +}); + builder.Services.AddOpenApi(options => { options.RemoveServerList(); @@ -114,7 +120,7 @@ documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchServi .DisableAntiforgery() .ProducesProblem(StatusCodes.Status400BadRequest) .WithSummary("Uploads a document") -.WithDescription("Uploads a document to SQL Database and saves its embedding using the new native Vector type. The document will be indexed and used to answer questions. Currently, only PDF files are supported."); +.WithDescription("Uploads a document to SQL Database and saves its embedding using the native VECTOR type. The document will be indexed and used to answer questions. Currently, only PDF files are supported."); documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, VectorSearchService vectorSearchService) => { @@ -134,4 +140,24 @@ app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSear .WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") .WithTags("Ask"); +app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, + [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => +{ + async IAsyncEnumerable Stream() + { + // Requests a streaming response. + var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); + + await foreach (var delta in responseStream) + { + yield return delta; + } + } + + return Stream(); +}) +.WithSummary("Asks a question and gets the response as streaming") +.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") +.WithTags("Ask"); + app.Run(); \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Services/ChatService.cs b/SqlDatabaseVectorSearch/Services/ChatService.cs index 4a109dc..ba1bee6 100644 --- a/SqlDatabaseVectorSearch/Services/ChatService.cs +++ b/SqlDatabaseVectorSearch/Services/ChatService.cs @@ -35,6 +35,42 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer } public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question) + { + var chat = CreateChatAsync(chunks, question); + + var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings + { + MaxTokens = appSettings.MaxOutputTokens + }); + + // Add question and answer to the chat history. + await SetChatHistoryAsync(conversationId, question, answer.Content!); + + return answer.Content!; + } + + public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question) + { + var chat = CreateChatAsync(chunks, question); + + var answer = new StringBuilder(); + await foreach (var token in chatCompletionService.GetStreamingChatMessageContentsAsync(chat, new AzureOpenAIPromptExecutionSettings + { + MaxTokens = appSettings.MaxOutputTokens + })) + { + if (!string.IsNullOrEmpty(token.Content)) + { + yield return token.Content; + answer.Append(token.Content); + } + } + + // Add question and answer to the chat history. + await SetChatHistoryAsync(conversationId, question, answer.ToString()); + } + + private ChatHistory CreateChatAsync(IEnumerable chunks, string question) { var chat = new ChatHistory(""" You can use only the information provided in this chat to answer questions. If you don't know the answer, reply suggesting to refine the question. @@ -79,16 +115,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer } chat.AddUserMessage(prompt.ToString()); - - var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings - { - MaxTokens = appSettings.MaxOutputTokens - }); - - // Add question and answer to the chat history. - await SetChatHistoryAsync(conversationId, question, answer.Content!); - - return answer.Content!; + return chat; } private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat) diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 1417ce0..62493bc 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -81,6 +81,33 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); public async Task AskQuestionAsync(Question question, bool reformulate = true) + { + var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); + + var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); + return new Response(reformulatedQuestion, answer); + } + + public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) + { + var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate); + + var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion); + + // The first message contains the original question. + yield return new Response(reformulatedQuestion, null, StreamState.Start); + + // Return each token as a partial response. + await foreach (var token in answerStream) + { + yield return new Response(null, token, StreamState.Append); + } + + // The last message tells the client that the stream has ended. + yield return new Response(null, null, StreamState.End); + } + + private async Task<(string Question, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) { // Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings: var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text; @@ -94,8 +121,7 @@ public class VectorSearchService(ApplicationDbContext dbContext, ITextEmbeddingG .Take(appSettings.MaxRelevantChunks) .ToListAsync(); - var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion); - return new Response(reformulatedQuestion, answer); + return (reformulatedQuestion, chunks); } private static Task GetContentAsync(Stream stream)