diff --git a/SqlDatabaseVectorSearch/ContentDecoders/DocxContentDecoder.cs b/SqlDatabaseVectorSearch/ContentDecoders/DocxContentDecoder.cs index 36b6d06..b67a45b 100644 --- a/SqlDatabaseVectorSearch/ContentDecoders/DocxContentDecoder.cs +++ b/SqlDatabaseVectorSearch/ContentDecoders/DocxContentDecoder.cs @@ -6,7 +6,7 @@ namespace SqlDatabaseVectorSearch.ContentDecoders; public class DocxContentDecoder : IContentDecoder { - public Task DecodeAsync(Stream stream, string contentType) + public Task DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default) { // Open a Word document for read-only access. using var document = WordprocessingDocument.Open(stream, false); diff --git a/SqlDatabaseVectorSearch/ContentDecoders/IContentDecoder.cs b/SqlDatabaseVectorSearch/ContentDecoders/IContentDecoder.cs index 87a736f..c5a46b5 100644 --- a/SqlDatabaseVectorSearch/ContentDecoders/IContentDecoder.cs +++ b/SqlDatabaseVectorSearch/ContentDecoders/IContentDecoder.cs @@ -2,5 +2,5 @@ public interface IContentDecoder { - Task DecodeAsync(Stream stream, string contentType); + Task DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default); } diff --git a/SqlDatabaseVectorSearch/ContentDecoders/PdfContentDecoder.cs b/SqlDatabaseVectorSearch/ContentDecoders/PdfContentDecoder.cs index ecb14a5..60710f7 100644 --- a/SqlDatabaseVectorSearch/ContentDecoders/PdfContentDecoder.cs +++ b/SqlDatabaseVectorSearch/ContentDecoders/PdfContentDecoder.cs @@ -6,7 +6,7 @@ namespace SqlDatabaseVectorSearch.ContentDecoders; public class PdfContentDecoder : IContentDecoder { - public Task DecodeAsync(Stream stream, string contentType) + public Task DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default) { var content = new StringBuilder(); diff --git a/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs b/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs index d9d6406..7b86637 100644 --- a/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs +++ b/SqlDatabaseVectorSearch/ContentDecoders/TextContentDecoder.cs @@ -2,10 +2,10 @@ public class TextContentDecoder : IContentDecoder { - public async Task DecodeAsync(Stream stream, string contentType) + public async Task DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default) { using var readStream = new StreamReader(stream); - var content = await readStream.ReadToEndAsync(); + var content = await readStream.ReadToEndAsync(cancellationToken); return content; } diff --git a/SqlDatabaseVectorSearch/Program.cs b/SqlDatabaseVectorSearch/Program.cs index 8a6c94b..469cd17 100644 --- a/SqlDatabaseVectorSearch/Program.cs +++ b/SqlDatabaseVectorSearch/Program.cs @@ -96,23 +96,23 @@ app.UseSwaggerUI(options => options.SwaggerEndpoint("/openapi/v1.json", builder.Environment.ApplicationName); }); -app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, +app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken, [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => { - var response = await vectorSearchService.AskQuestionAsync(question, reformulate); + var response = await vectorSearchService.AskQuestionAsync(question, reformulate, cancellationToken); return TypedResults.Ok(response); }) .WithSummary("Asks a question") .WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") .WithTags("Ask"); -app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, +app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken, [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => { async IAsyncEnumerable Stream() { // Requests a streaming response. - var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); + var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken); await foreach (var delta in responseStream) { @@ -128,24 +128,24 @@ app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vector var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents"); -documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService) => +documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService, CancellationToken cancellationToken) => { - var documents = await documentService.GetDocumentsAsync(); + var documents = await documentService.GetDocumentsAsync(cancellationToken); return TypedResults.Ok(documents); }) .WithSummary("Gets the list of documents"); -documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService) => +documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService, CancellationToken cancellationToken) => { - var documents = await documentService.GetDocumentChunksAsync(documentId); + var documents = await documentService.GetDocumentChunksAsync(documentId, cancellationToken); return TypedResults.Ok(documents); }) .WithSummary("Gets the list of chunks of a given document") .WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk."); -documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService) => +documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService, CancellationToken cancellationToken) => { - var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId); + var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId, cancellationToken); if (chunk is null) { return TypedResults.NotFound(); @@ -156,19 +156,19 @@ documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", asyn .ProducesProblem(StatusCodes.Status404NotFound) .WithSummary("Gets the details of a given chunk, includings its embedding"); -documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService) => +documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService, CancellationToken cancellationToken) => { - await documentService.DeleteDocumentAsync(documentId); + await documentService.DeleteDocumentAsync(documentId, cancellationToken); return TypedResults.NoContent(); }) .WithSummary("Deletes a document") .WithDescription("This endpoint deletes the document and all its chunks."); -documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, +documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken, [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => { using var stream = file.OpenReadStream(); - var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId); + var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId, cancellationToken); return TypedResults.Ok(response); }) diff --git a/SqlDatabaseVectorSearch/Services/ChatService.cs b/SqlDatabaseVectorSearch/Services/ChatService.cs index f2e6ede..172cd96 100644 --- a/SqlDatabaseVectorSearch/Services/ChatService.cs +++ b/SqlDatabaseVectorSearch/Services/ChatService.cs @@ -1,4 +1,5 @@ -using System.Text; +using System.Runtime.CompilerServices; +using System.Text; using Microsoft.Extensions.Caching.Hybrid; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.ChatCompletion; @@ -13,9 +14,9 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer { private readonly AppSettings appSettings = appSettingsOptions.Value; - public async Task CreateQuestionAsync(Guid conversationId, string question) + public async Task CreateQuestionAsync(Guid conversationId, string question, CancellationToken cancellationToken = default) { - var chat = await GetChatHistoryAsync(conversationId); + var chat = await GetChatHistoryAsync(conversationId, cancellationToken); var embeddingQuestion = $""" Reformulate the following question taking into account the context of the chat to perform embeddings search: @@ -28,34 +29,34 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer chat.AddUserMessage(embeddingQuestion); - var reformulatedQuestion = await chatCompletionService.GetChatMessageContentAsync(chat)!; + var reformulatedQuestion = await chatCompletionService.GetChatMessageContentAsync(chat, cancellationToken: cancellationToken); chat.AddAssistantMessage(reformulatedQuestion.Content!); - await UpdateCacheAsync(conversationId, chat); + await UpdateCacheAsync(conversationId, chat, cancellationToken); var tokenUsage = GetTokenUsage(reformulatedQuestion); return new(reformulatedQuestion.Content!, tokenUsage); } - public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question) + public async Task AskQuestionAsync(Guid conversationId, IEnumerable chunks, string question, CancellationToken cancellationToken = default) { var chat = CreateChatAsync(chunks, question); var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings { MaxTokens = appSettings.MaxOutputTokens - }); + }, cancellationToken: cancellationToken); // Add question and answer to the chat history. - await SetChatHistoryAsync(conversationId, question, answer.Content!); + await SetChatHistoryAsync(conversationId, question, answer.Content!, cancellationToken); var tokenUsage = GetTokenUsage(answer); return new(answer.Content!, tokenUsage); } - public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question) + public async IAsyncEnumerable AskStreamingAsync(Guid conversationId, IEnumerable chunks, string question, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var chat = CreateChatAsync(chunks, question); @@ -63,7 +64,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer await foreach (var token in chatCompletionService.GetStreamingChatMessageContentsAsync(chat, new AzureOpenAIPromptExecutionSettings { MaxTokens = appSettings.MaxOutputTokens - })) + }, cancellationToken: cancellationToken)) { if (!string.IsNullOrEmpty(token.Content)) { @@ -82,7 +83,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer } // Add question and answer to the chat history. - await SetChatHistoryAsync(conversationId, question, answer.ToString()); + await SetChatHistoryAsync(conversationId, question, answer.ToString(), cancellationToken); } private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message) @@ -153,28 +154,27 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer return chat; } - private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat) - => await cache.SetAsync(conversationId.ToString(), chat); + private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat, CancellationToken cancellationToken) + => await cache.SetAsync(conversationId.ToString(), chat, cancellationToken: cancellationToken); - private async Task GetChatHistoryAsync(Guid conversationId) + private async Task GetChatHistoryAsync(Guid conversationId, CancellationToken cancellationToken) { - var historyCache = await cache.GetOrCreateAsync(conversationId.ToString(), - (cancellationToken) => + var historyCache = await cache.GetOrCreateAsync(conversationId.ToString(), (cancellationToken) => { return ValueTask.FromResult([]); - }); + }, cancellationToken: cancellationToken); var chat = new ChatHistory(historyCache); return chat; } - private async Task SetChatHistoryAsync(Guid conversationId, string question, string answer) + private async Task SetChatHistoryAsync(Guid conversationId, string question, string answer, CancellationToken cancellationToken) { - var history = await GetChatHistoryAsync(conversationId); + var history = await GetChatHistoryAsync(conversationId, cancellationToken); history.AddUserMessage(question); history.AddAssistantMessage(answer); - await UpdateCacheAsync(conversationId, history); + await UpdateCacheAsync(conversationId, history, cancellationToken); } } diff --git a/SqlDatabaseVectorSearch/Services/DocumentService.cs b/SqlDatabaseVectorSearch/Services/DocumentService.cs index b6a7ed5..f982974 100644 --- a/SqlDatabaseVectorSearch/Services/DocumentService.cs +++ b/SqlDatabaseVectorSearch/Services/DocumentService.cs @@ -7,33 +7,33 @@ namespace SqlDatabaseVectorSearch.Services; public class DocumentService(ApplicationDbContext dbContext) { - public async Task> GetDocumentsAsync() + public async Task> GetDocumentsAsync(CancellationToken cancellationToken = default) { var documents = await dbContext.Documents.OrderBy(d => d.Name) .Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count)) - .ToListAsync(); + .ToListAsync(cancellationToken: cancellationToken); return documents; } - public async Task> GetDocumentChunksAsync(Guid documentId) + public async Task> GetDocumentChunksAsync(Guid documentId, CancellationToken cancellationToken = default) { var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index) .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null)) - .ToListAsync(); + .ToListAsync(cancellationToken: cancellationToken); return documentChunks; } - public async Task GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) + public async Task GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId, CancellationToken cancellationToken = default) { var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId) .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding)) - .FirstOrDefaultAsync(); + .FirstOrDefaultAsync(cancellationToken: cancellationToken); return documentChunk; } - public Task DeleteDocumentAsync(Guid documentId) - => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); + public Task DeleteDocumentAsync(Guid documentId, CancellationToken cancellationToken = default) + => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(cancellationToken: cancellationToken); } \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs index 74aa741..c4c44e1 100644 --- a/SqlDatabaseVectorSearch/Services/VectorSearchService.cs +++ b/SqlDatabaseVectorSearch/Services/VectorSearchService.cs @@ -1,4 +1,6 @@ using System.Data; +using System.Runtime.CompilerServices; +using System.Threading; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Options; using Microsoft.SemanticKernel.Embeddings; @@ -14,24 +16,24 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb { private readonly AppSettings appSettings = appSettingsOptions.Value; - public async Task ImportAsync(Stream stream, string name, string contentType, Guid? documentId) + public async Task ImportAsync(Stream stream, string name, string contentType, Guid? documentId, CancellationToken cancellationToken = default) { // Extract the contents of the file. var decoder = serviceProvider.GetKeyedService(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported."); - var content = await decoder.DecodeAsync(stream, contentType); + var content = await decoder.DecodeAsync(stream, contentType, cancellationToken); // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). var tokenCount = tokenizerService.CountEmbeddingTokens(content); var strategy = dbContext.Database.CreateExecutionStrategy(); - var document = await strategy.ExecuteAsync(async () => + var document = await strategy.ExecuteAsync(async (cancellationToken) => { await dbContext.Database.BeginTransactionAsync(); if (documentId.HasValue) { // If the user is importing a document that already exists, delete the previous one. - await documentService.DeleteDocumentAsync(documentId.Value); + await documentService.DeleteDocumentAsync(documentId.Value, cancellationToken); } var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; @@ -39,7 +41,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb // Split the content into chunks and generate the embeddings for each one. var paragraphs = textChunkerService.Split(content); - var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); + var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs, cancellationToken: cancellationToken); // Save the document chunks and the corresponding embedding in the database. foreach (var (index, paragraph) in paragraphs.Index()) @@ -50,31 +52,31 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb dbContext.DocumentChunks.Add(documentChunk); } - await dbContext.SaveChangesAsync(); - await dbContext.Database.CommitTransactionAsync(); + await dbContext.SaveChangesAsync(cancellationToken); + await dbContext.Database.CommitTransactionAsync(cancellationToken); return document; - }); + }, cancellationToken); return new(document.Id, tokenCount); } - public async Task AskQuestionAsync(Question question, bool reformulate = true) + public async Task AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default) { // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. - var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); + var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken); - var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); + var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken); return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); } - public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true) + public async IAsyncEnumerable AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default) { // It the user doesn't want to reforulate the question, CreateContextAsync returns the original one. - var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); + var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken); - var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); + var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken: cancellationToken); // The first message contains the question and the corresponding token usage (if reformulated). yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); @@ -96,20 +98,20 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb } } - private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate = true) + private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable Chunks)> CreateContextAsync(Question question, bool reformulate, CancellationToken cancellationToken) { // Reformulate the question taking into account the context of the chat to perform keyword search and embeddings. - var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text); + var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text, cancellationToken) : new(question.Text); var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!); // Perform Vector Search on SQL Database. - var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!); + var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!, cancellationToken: cancellationToken); var chunks = await dbContext.DocumentChunks .OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) .Select(c => c.Content) .Take(appSettings.MaxRelevantChunks) - .ToListAsync(); + .ToListAsync(cancellationToken); return (reformulatedQuestion, embeddingTokenCount, chunks); }