Add CancellationToken support to async methods #9

Introduce `CancellationToken` support across the async methods of the content decoders, the endpoint handlers, and the service classes, so that in-flight work can be cancelled and the application stays responsive.
- Update `DecodeAsync` method in `DocxContentDecoder.cs`, `PdfContentDecoder.cs`, `TextContentDecoder.cs`, and `IContentDecoder.cs` to include an optional `CancellationToken` parameter.
- Modify endpoint handlers in `Program.cs` to accept and pass `CancellationToken` parameters.
- Update methods in `ChatService.cs` to include `CancellationToken` parameters.
- Update methods in `DocumentService.cs` to include `CancellationToken` parameters.
- Update methods in `VectorSearchService.cs` to include `CancellationToken` parameters.
These changes ensure that long-running operations can be canceled if needed, improving the application's ability to handle cancellation requests gracefully.
This commit is contained in:
Marco Minerva
2025-02-10 16:20:35 +01:00
parent d0fce6ffd2
commit a0f1755c85
8 changed files with 67 additions and 65 deletions
@@ -6,7 +6,7 @@ namespace SqlDatabaseVectorSearch.ContentDecoders;
public class DocxContentDecoder : IContentDecoder public class DocxContentDecoder : IContentDecoder
{ {
public Task<string> DecodeAsync(Stream stream, string contentType) public Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
{ {
// Open a Word document for read-only access. // Open a Word document for read-only access.
using var document = WordprocessingDocument.Open(stream, false); using var document = WordprocessingDocument.Open(stream, false);
@@ -2,5 +2,5 @@
public interface IContentDecoder public interface IContentDecoder
{ {
Task<string> DecodeAsync(Stream stream, string contentType); Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default);
} }
@@ -6,7 +6,7 @@ namespace SqlDatabaseVectorSearch.ContentDecoders;
public class PdfContentDecoder : IContentDecoder public class PdfContentDecoder : IContentDecoder
{ {
public Task<string> DecodeAsync(Stream stream, string contentType) public Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
{ {
var content = new StringBuilder(); var content = new StringBuilder();
@@ -2,10 +2,10 @@
public class TextContentDecoder : IContentDecoder public class TextContentDecoder : IContentDecoder
{ {
public async Task<string> DecodeAsync(Stream stream, string contentType) public async Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
{ {
using var readStream = new StreamReader(stream); using var readStream = new StreamReader(stream);
var content = await readStream.ReadToEndAsync(); var content = await readStream.ReadToEndAsync(cancellationToken);
return content; return content;
} }
+14 -14
View File
@@ -96,23 +96,23 @@ app.UseSwaggerUI(options =>
options.SwaggerEndpoint("/openapi/v1.json", builder.Environment.ApplicationName); options.SwaggerEndpoint("/openapi/v1.json", builder.Environment.ApplicationName);
}); });
app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
{ {
var response = await vectorSearchService.AskQuestionAsync(question, reformulate); var response = await vectorSearchService.AskQuestionAsync(question, reformulate, cancellationToken);
return TypedResults.Ok(response); return TypedResults.Ok(response);
}) })
.WithSummary("Asks a question") .WithSummary("Asks a question")
.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.") .WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.")
.WithTags("Ask"); .WithTags("Ask");
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) => [Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
{ {
async IAsyncEnumerable<QuestionResponse> Stream() async IAsyncEnumerable<QuestionResponse> Stream()
{ {
// Requests a streaming response. // Requests a streaming response.
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate); var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken);
await foreach (var delta in responseStream) await foreach (var delta in responseStream)
{ {
@@ -128,24 +128,24 @@ app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vector
var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents"); var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents");
documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService) => documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService, CancellationToken cancellationToken) =>
{ {
var documents = await documentService.GetDocumentsAsync(); var documents = await documentService.GetDocumentsAsync(cancellationToken);
return TypedResults.Ok(documents); return TypedResults.Ok(documents);
}) })
.WithSummary("Gets the list of documents"); .WithSummary("Gets the list of documents");
documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService) => documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService, CancellationToken cancellationToken) =>
{ {
var documents = await documentService.GetDocumentChunksAsync(documentId); var documents = await documentService.GetDocumentChunksAsync(documentId, cancellationToken);
return TypedResults.Ok(documents); return TypedResults.Ok(documents);
}) })
.WithSummary("Gets the list of chunks of a given document") .WithSummary("Gets the list of chunks of a given document")
.WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk."); .WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk.");
documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task<Results<Ok<DocumentChunk>, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService) => documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task<Results<Ok<DocumentChunk>, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService, CancellationToken cancellationToken) =>
{ {
var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId); var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId, cancellationToken);
if (chunk is null) if (chunk is null)
{ {
return TypedResults.NotFound(); return TypedResults.NotFound();
@@ -156,19 +156,19 @@ documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", asyn
.ProducesProblem(StatusCodes.Status404NotFound) .ProducesProblem(StatusCodes.Status404NotFound)
.WithSummary("Gets the details of a given chunk, includings its embedding"); .WithSummary("Gets the details of a given chunk, includings its embedding");
documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService) => documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService, CancellationToken cancellationToken) =>
{ {
await documentService.DeleteDocumentAsync(documentId); await documentService.DeleteDocumentAsync(documentId, cancellationToken);
return TypedResults.NoContent(); return TypedResults.NoContent();
}) })
.WithSummary("Deletes a document") .WithSummary("Deletes a document")
.WithDescription("This endpoint deletes the document and all its chunks."); .WithDescription("This endpoint deletes the document and all its chunks.");
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) => [Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
{ {
using var stream = file.OpenReadStream(); using var stream = file.OpenReadStream();
var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId); var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId, cancellationToken);
return TypedResults.Ok(response); return TypedResults.Ok(response);
}) })
+20 -20
View File
@@ -1,4 +1,5 @@
using System.Text; using System.Runtime.CompilerServices;
using System.Text;
using Microsoft.Extensions.Caching.Hybrid; using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.ChatCompletion;
@@ -13,9 +14,9 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
{ {
private readonly AppSettings appSettings = appSettingsOptions.Value; private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question) public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question, CancellationToken cancellationToken = default)
{ {
var chat = await GetChatHistoryAsync(conversationId); var chat = await GetChatHistoryAsync(conversationId, cancellationToken);
var embeddingQuestion = $""" var embeddingQuestion = $"""
Reformulate the following question taking into account the context of the chat to perform embeddings search: Reformulate the following question taking into account the context of the chat to perform embeddings search:
@@ -28,34 +29,34 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
chat.AddUserMessage(embeddingQuestion); chat.AddUserMessage(embeddingQuestion);
var reformulatedQuestion = await chatCompletionService.GetChatMessageContentAsync(chat)!; var reformulatedQuestion = await chatCompletionService.GetChatMessageContentAsync(chat, cancellationToken: cancellationToken);
chat.AddAssistantMessage(reformulatedQuestion.Content!); chat.AddAssistantMessage(reformulatedQuestion.Content!);
await UpdateCacheAsync(conversationId, chat); await UpdateCacheAsync(conversationId, chat, cancellationToken);
var tokenUsage = GetTokenUsage(reformulatedQuestion); var tokenUsage = GetTokenUsage(reformulatedQuestion);
return new(reformulatedQuestion.Content!, tokenUsage); return new(reformulatedQuestion.Content!, tokenUsage);
} }
public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question) public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question, CancellationToken cancellationToken = default)
{ {
var chat = CreateChatAsync(chunks, question); var chat = CreateChatAsync(chunks, question);
var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings
{ {
MaxTokens = appSettings.MaxOutputTokens MaxTokens = appSettings.MaxOutputTokens
}); }, cancellationToken: cancellationToken);
// Add question and answer to the chat history. // Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.Content!); await SetChatHistoryAsync(conversationId, question, answer.Content!, cancellationToken);
var tokenUsage = GetTokenUsage(answer); var tokenUsage = GetTokenUsage(answer);
return new(answer.Content!, tokenUsage); return new(answer.Content!, tokenUsage);
} }
public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question) public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{ {
var chat = CreateChatAsync(chunks, question); var chat = CreateChatAsync(chunks, question);
@@ -63,7 +64,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
await foreach (var token in chatCompletionService.GetStreamingChatMessageContentsAsync(chat, new AzureOpenAIPromptExecutionSettings await foreach (var token in chatCompletionService.GetStreamingChatMessageContentsAsync(chat, new AzureOpenAIPromptExecutionSettings
{ {
MaxTokens = appSettings.MaxOutputTokens MaxTokens = appSettings.MaxOutputTokens
})) }, cancellationToken: cancellationToken))
{ {
if (!string.IsNullOrEmpty(token.Content)) if (!string.IsNullOrEmpty(token.Content))
{ {
@@ -82,7 +83,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
} }
// Add question and answer to the chat history. // Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.ToString()); await SetChatHistoryAsync(conversationId, question, answer.ToString(), cancellationToken);
} }
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message) private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message)
@@ -153,28 +154,27 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
return chat; return chat;
} }
private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat) private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat, CancellationToken cancellationToken)
=> await cache.SetAsync(conversationId.ToString(), chat); => await cache.SetAsync(conversationId.ToString(), chat, cancellationToken: cancellationToken);
private async Task<ChatHistory> GetChatHistoryAsync(Guid conversationId) private async Task<ChatHistory> GetChatHistoryAsync(Guid conversationId, CancellationToken cancellationToken)
{ {
var historyCache = await cache.GetOrCreateAsync(conversationId.ToString(), var historyCache = await cache.GetOrCreateAsync(conversationId.ToString(), (cancellationToken) =>
(cancellationToken) =>
{ {
return ValueTask.FromResult<ChatHistory>([]); return ValueTask.FromResult<ChatHistory>([]);
}); }, cancellationToken: cancellationToken);
var chat = new ChatHistory(historyCache); var chat = new ChatHistory(historyCache);
return chat; return chat;
} }
private async Task SetChatHistoryAsync(Guid conversationId, string question, string answer) private async Task SetChatHistoryAsync(Guid conversationId, string question, string answer, CancellationToken cancellationToken)
{ {
var history = await GetChatHistoryAsync(conversationId); var history = await GetChatHistoryAsync(conversationId, cancellationToken);
history.AddUserMessage(question); history.AddUserMessage(question);
history.AddAssistantMessage(answer); history.AddAssistantMessage(answer);
await UpdateCacheAsync(conversationId, history); await UpdateCacheAsync(conversationId, history, cancellationToken);
} }
} }
@@ -7,33 +7,33 @@ namespace SqlDatabaseVectorSearch.Services;
public class DocumentService(ApplicationDbContext dbContext) public class DocumentService(ApplicationDbContext dbContext)
{ {
public async Task<IEnumerable<Document>> GetDocumentsAsync() public async Task<IEnumerable<Document>> GetDocumentsAsync(CancellationToken cancellationToken = default)
{ {
var documents = await dbContext.Documents.OrderBy(d => d.Name) var documents = await dbContext.Documents.OrderBy(d => d.Name)
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count)) .Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count))
.ToListAsync(); .ToListAsync(cancellationToken: cancellationToken);
return documents; return documents;
} }
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId) public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId, CancellationToken cancellationToken = default)
{ {
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index) var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null)) .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null))
.ToListAsync(); .ToListAsync(cancellationToken: cancellationToken);
return documentChunks; return documentChunks;
} }
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId) public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId, CancellationToken cancellationToken = default)
{ {
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId) var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding)) .Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding))
.FirstOrDefaultAsync(); .FirstOrDefaultAsync(cancellationToken: cancellationToken);
return documentChunk; return documentChunk;
} }
public Task DeleteDocumentAsync(Guid documentId) public Task DeleteDocumentAsync(Guid documentId, CancellationToken cancellationToken = default)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(); => dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(cancellationToken: cancellationToken);
} }
@@ -1,4 +1,6 @@
using System.Data; using System.Data;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Embeddings;
@@ -14,24 +16,24 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
{ {
private readonly AppSettings appSettings = appSettingsOptions.Value; private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId) public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId, CancellationToken cancellationToken = default)
{ {
// Extract the contents of the file. // Extract the contents of the file.
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported."); var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
var content = await decoder.DecodeAsync(stream, contentType); var content = await decoder.DecodeAsync(stream, contentType, cancellationToken);
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis). // We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
var tokenCount = tokenizerService.CountEmbeddingTokens(content); var tokenCount = tokenizerService.CountEmbeddingTokens(content);
var strategy = dbContext.Database.CreateExecutionStrategy(); var strategy = dbContext.Database.CreateExecutionStrategy();
var document = await strategy.ExecuteAsync(async () => var document = await strategy.ExecuteAsync(async (cancellationToken) =>
{ {
await dbContext.Database.BeginTransactionAsync(); await dbContext.Database.BeginTransactionAsync();
if (documentId.HasValue) if (documentId.HasValue)
{ {
// If the user is importing a document that already exists, delete the previous one. // If the user is importing a document that already exists, delete the previous one.
await documentService.DeleteDocumentAsync(documentId.Value); await documentService.DeleteDocumentAsync(documentId.Value, cancellationToken);
} }
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() }; var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
@@ -39,7 +41,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
// Split the content into chunks and generate the embeddings for each one. // Split the content into chunks and generate the embeddings for each one.
var paragraphs = textChunkerService.Split(content); var paragraphs = textChunkerService.Split(content);
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs); var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs, cancellationToken: cancellationToken);
// Save the document chunks and the corresponding embedding in the database. // Save the document chunks and the corresponding embedding in the database.
foreach (var (index, paragraph) in paragraphs.Index()) foreach (var (index, paragraph) in paragraphs.Index())
@@ -50,31 +52,31 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
dbContext.DocumentChunks.Add(documentChunk); dbContext.DocumentChunks.Add(documentChunk);
} }
await dbContext.SaveChangesAsync(); await dbContext.SaveChangesAsync(cancellationToken);
await dbContext.Database.CommitTransactionAsync(); await dbContext.Database.CommitTransactionAsync(cancellationToken);
return document; return document;
}); }, cancellationToken);
return new(document.Id, tokenCount); return new(document.Id, tokenCount);
} }
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true) public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
{ {
// If the user doesn't want to reformulate the question, CreateContextAsync returns the original one. // If the user doesn't want to reformulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken);
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
} }
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true) public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{ {
// If the user doesn't want to reformulate the question, CreateContextAsync returns the original one. // If the user doesn't want to reformulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate); var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!); var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken: cancellationToken);
// The first message contains the question and the corresponding token usage (if reformulated). // The first message contains the question and the corresponding token usage (if reformulated).
yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
@@ -96,20 +98,20 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
} }
} }
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true) private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate, CancellationToken cancellationToken)
{ {
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings. // Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text); var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text, cancellationToken) : new(question.Text);
var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!); var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!);
// Perform Vector Search on SQL Database. // Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!); var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!, cancellationToken: cancellationToken);
var chunks = await dbContext.DocumentChunks var chunks = await dbContext.DocumentChunks
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray())) .OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
.Select(c => c.Content) .Select(c => c.Content)
.Take(appSettings.MaxRelevantChunks) .Take(appSettings.MaxRelevantChunks)
.ToListAsync(); .ToListAsync(cancellationToken);
return (reformulatedQuestion, embeddingTokenCount, chunks); return (reformulatedQuestion, embeddingTokenCount, chunks);
} }