mirror of
https://github.com/marcominerva/SqlDatabaseVectorSearch.git
synced 2026-06-20 12:23:10 +00:00
Add CancellationToken support to async methods #9
Introduce support for `CancellationToken` across various methods to allow for task cancellation and improve responsiveness. - Update `DecodeAsync` method in `DocxContentDecoder.cs`, `PdfContentDecoder.cs`, `TextContentDecoder.cs`, and `IContentDecoder.cs` to include an optional `CancellationToken` parameter. - Modify endpoint handlers in `Program.cs` to accept and pass `CancellationToken` parameters. - Update methods in `ChatService.cs` to include `CancellationToken` parameters. - Update methods in `DocumentService.cs` to include `CancellationToken` parameters. - Update methods in `VectorSearchService.cs` to include `CancellationToken` parameters. These changes ensure that long-running operations can be canceled if needed, improving the application's ability to handle cancellation requests gracefully.
This commit is contained in:
@@ -6,7 +6,7 @@ namespace SqlDatabaseVectorSearch.ContentDecoders;
|
|||||||
|
|
||||||
public class DocxContentDecoder : IContentDecoder
|
public class DocxContentDecoder : IContentDecoder
|
||||||
{
|
{
|
||||||
public Task<string> DecodeAsync(Stream stream, string contentType)
|
public Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
// Open a Word document for read-only access.
|
// Open a Word document for read-only access.
|
||||||
using var document = WordprocessingDocument.Open(stream, false);
|
using var document = WordprocessingDocument.Open(stream, false);
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
|
|
||||||
public interface IContentDecoder
|
public interface IContentDecoder
|
||||||
{
|
{
|
||||||
Task<string> DecodeAsync(Stream stream, string contentType);
|
Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ namespace SqlDatabaseVectorSearch.ContentDecoders;
|
|||||||
|
|
||||||
public class PdfContentDecoder : IContentDecoder
|
public class PdfContentDecoder : IContentDecoder
|
||||||
{
|
{
|
||||||
public Task<string> DecodeAsync(Stream stream, string contentType)
|
public Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
var content = new StringBuilder();
|
var content = new StringBuilder();
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
public class TextContentDecoder : IContentDecoder
|
public class TextContentDecoder : IContentDecoder
|
||||||
{
|
{
|
||||||
public async Task<string> DecodeAsync(Stream stream, string contentType)
|
public async Task<string> DecodeAsync(Stream stream, string contentType, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
using var readStream = new StreamReader(stream);
|
using var readStream = new StreamReader(stream);
|
||||||
var content = await readStream.ReadToEndAsync();
|
var content = await readStream.ReadToEndAsync(cancellationToken);
|
||||||
|
|
||||||
return content;
|
return content;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,23 +96,23 @@ app.UseSwaggerUI(options =>
|
|||||||
options.SwaggerEndpoint("/openapi/v1.json", builder.Environment.ApplicationName);
|
options.SwaggerEndpoint("/openapi/v1.json", builder.Environment.ApplicationName);
|
||||||
});
|
});
|
||||||
|
|
||||||
app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService,
|
app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
||||||
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
||||||
{
|
{
|
||||||
var response = await vectorSearchService.AskQuestionAsync(question, reformulate);
|
var response = await vectorSearchService.AskQuestionAsync(question, reformulate, cancellationToken);
|
||||||
return TypedResults.Ok(response);
|
return TypedResults.Ok(response);
|
||||||
})
|
})
|
||||||
.WithSummary("Asks a question")
|
.WithSummary("Asks a question")
|
||||||
.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.")
|
.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.")
|
||||||
.WithTags("Ask");
|
.WithTags("Ask");
|
||||||
|
|
||||||
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService,
|
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
||||||
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
|
||||||
{
|
{
|
||||||
async IAsyncEnumerable<QuestionResponse> Stream()
|
async IAsyncEnumerable<QuestionResponse> Stream()
|
||||||
{
|
{
|
||||||
// Requests a streaming response.
|
// Requests a streaming response.
|
||||||
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate);
|
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate, cancellationToken);
|
||||||
|
|
||||||
await foreach (var delta in responseStream)
|
await foreach (var delta in responseStream)
|
||||||
{
|
{
|
||||||
@@ -128,24 +128,24 @@ app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vector
|
|||||||
|
|
||||||
var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents");
|
var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents");
|
||||||
|
|
||||||
documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService) =>
|
documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService, CancellationToken cancellationToken) =>
|
||||||
{
|
{
|
||||||
var documents = await documentService.GetDocumentsAsync();
|
var documents = await documentService.GetDocumentsAsync(cancellationToken);
|
||||||
return TypedResults.Ok(documents);
|
return TypedResults.Ok(documents);
|
||||||
})
|
})
|
||||||
.WithSummary("Gets the list of documents");
|
.WithSummary("Gets the list of documents");
|
||||||
|
|
||||||
documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService) =>
|
documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService, CancellationToken cancellationToken) =>
|
||||||
{
|
{
|
||||||
var documents = await documentService.GetDocumentChunksAsync(documentId);
|
var documents = await documentService.GetDocumentChunksAsync(documentId, cancellationToken);
|
||||||
return TypedResults.Ok(documents);
|
return TypedResults.Ok(documents);
|
||||||
})
|
})
|
||||||
.WithSummary("Gets the list of chunks of a given document")
|
.WithSummary("Gets the list of chunks of a given document")
|
||||||
.WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk.");
|
.WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk.");
|
||||||
|
|
||||||
documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task<Results<Ok<DocumentChunk>, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService) =>
|
documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task<Results<Ok<DocumentChunk>, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService, CancellationToken cancellationToken) =>
|
||||||
{
|
{
|
||||||
var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId);
|
var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId, cancellationToken);
|
||||||
if (chunk is null)
|
if (chunk is null)
|
||||||
{
|
{
|
||||||
return TypedResults.NotFound();
|
return TypedResults.NotFound();
|
||||||
@@ -156,19 +156,19 @@ documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", asyn
|
|||||||
.ProducesProblem(StatusCodes.Status404NotFound)
|
.ProducesProblem(StatusCodes.Status404NotFound)
|
||||||
.WithSummary("Gets the details of a given chunk, includings its embedding");
|
.WithSummary("Gets the details of a given chunk, includings its embedding");
|
||||||
|
|
||||||
documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService) =>
|
documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService, CancellationToken cancellationToken) =>
|
||||||
{
|
{
|
||||||
await documentService.DeleteDocumentAsync(documentId);
|
await documentService.DeleteDocumentAsync(documentId, cancellationToken);
|
||||||
return TypedResults.NoContent();
|
return TypedResults.NoContent();
|
||||||
})
|
})
|
||||||
.WithSummary("Deletes a document")
|
.WithSummary("Deletes a document")
|
||||||
.WithDescription("This endpoint deletes the document and all its chunks.");
|
.WithDescription("This endpoint deletes the document and all its chunks.");
|
||||||
|
|
||||||
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService,
|
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService, CancellationToken cancellationToken,
|
||||||
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
|
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
|
||||||
{
|
{
|
||||||
using var stream = file.OpenReadStream();
|
using var stream = file.OpenReadStream();
|
||||||
var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
|
var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId, cancellationToken);
|
||||||
|
|
||||||
return TypedResults.Ok(response);
|
return TypedResults.Ok(response);
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
using System.Text;
|
using System.Runtime.CompilerServices;
|
||||||
|
using System.Text;
|
||||||
using Microsoft.Extensions.Caching.Hybrid;
|
using Microsoft.Extensions.Caching.Hybrid;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Microsoft.SemanticKernel.ChatCompletion;
|
using Microsoft.SemanticKernel.ChatCompletion;
|
||||||
@@ -13,9 +14,9 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
{
|
{
|
||||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||||
|
|
||||||
public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question)
|
public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
var chat = await GetChatHistoryAsync(conversationId);
|
var chat = await GetChatHistoryAsync(conversationId, cancellationToken);
|
||||||
|
|
||||||
var embeddingQuestion = $"""
|
var embeddingQuestion = $"""
|
||||||
Reformulate the following question taking into account the context of the chat to perform embeddings search:
|
Reformulate the following question taking into account the context of the chat to perform embeddings search:
|
||||||
@@ -28,34 +29,34 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
|
|
||||||
chat.AddUserMessage(embeddingQuestion);
|
chat.AddUserMessage(embeddingQuestion);
|
||||||
|
|
||||||
var reformulatedQuestion = await chatCompletionService.GetChatMessageContentAsync(chat)!;
|
var reformulatedQuestion = await chatCompletionService.GetChatMessageContentAsync(chat, cancellationToken: cancellationToken);
|
||||||
chat.AddAssistantMessage(reformulatedQuestion.Content!);
|
chat.AddAssistantMessage(reformulatedQuestion.Content!);
|
||||||
|
|
||||||
await UpdateCacheAsync(conversationId, chat);
|
await UpdateCacheAsync(conversationId, chat, cancellationToken);
|
||||||
|
|
||||||
var tokenUsage = GetTokenUsage(reformulatedQuestion);
|
var tokenUsage = GetTokenUsage(reformulatedQuestion);
|
||||||
|
|
||||||
return new(reformulatedQuestion.Content!, tokenUsage);
|
return new(reformulatedQuestion.Content!, tokenUsage);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
|
public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
var chat = CreateChatAsync(chunks, question);
|
var chat = CreateChatAsync(chunks, question);
|
||||||
|
|
||||||
var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings
|
var answer = await chatCompletionService.GetChatMessageContentAsync(chat, new AzureOpenAIPromptExecutionSettings
|
||||||
{
|
{
|
||||||
MaxTokens = appSettings.MaxOutputTokens
|
MaxTokens = appSettings.MaxOutputTokens
|
||||||
});
|
}, cancellationToken: cancellationToken);
|
||||||
|
|
||||||
// Add question and answer to the chat history.
|
// Add question and answer to the chat history.
|
||||||
await SetChatHistoryAsync(conversationId, question, answer.Content!);
|
await SetChatHistoryAsync(conversationId, question, answer.Content!, cancellationToken);
|
||||||
|
|
||||||
var tokenUsage = GetTokenUsage(answer);
|
var tokenUsage = GetTokenUsage(answer);
|
||||||
|
|
||||||
return new(answer.Content!, tokenUsage);
|
return new(answer.Content!, tokenUsage);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
|
public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
var chat = CreateChatAsync(chunks, question);
|
var chat = CreateChatAsync(chunks, question);
|
||||||
|
|
||||||
@@ -63,7 +64,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
await foreach (var token in chatCompletionService.GetStreamingChatMessageContentsAsync(chat, new AzureOpenAIPromptExecutionSettings
|
await foreach (var token in chatCompletionService.GetStreamingChatMessageContentsAsync(chat, new AzureOpenAIPromptExecutionSettings
|
||||||
{
|
{
|
||||||
MaxTokens = appSettings.MaxOutputTokens
|
MaxTokens = appSettings.MaxOutputTokens
|
||||||
}))
|
}, cancellationToken: cancellationToken))
|
||||||
{
|
{
|
||||||
if (!string.IsNullOrEmpty(token.Content))
|
if (!string.IsNullOrEmpty(token.Content))
|
||||||
{
|
{
|
||||||
@@ -82,7 +83,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add question and answer to the chat history.
|
// Add question and answer to the chat history.
|
||||||
await SetChatHistoryAsync(conversationId, question, answer.ToString());
|
await SetChatHistoryAsync(conversationId, question, answer.ToString(), cancellationToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message)
|
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message)
|
||||||
@@ -153,28 +154,27 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
|
|||||||
return chat;
|
return chat;
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat)
|
private async Task UpdateCacheAsync(Guid conversationId, ChatHistory chat, CancellationToken cancellationToken)
|
||||||
=> await cache.SetAsync(conversationId.ToString(), chat);
|
=> await cache.SetAsync(conversationId.ToString(), chat, cancellationToken: cancellationToken);
|
||||||
|
|
||||||
private async Task<ChatHistory> GetChatHistoryAsync(Guid conversationId)
|
private async Task<ChatHistory> GetChatHistoryAsync(Guid conversationId, CancellationToken cancellationToken)
|
||||||
{
|
{
|
||||||
var historyCache = await cache.GetOrCreateAsync(conversationId.ToString(),
|
var historyCache = await cache.GetOrCreateAsync(conversationId.ToString(), (cancellationToken) =>
|
||||||
(cancellationToken) =>
|
|
||||||
{
|
{
|
||||||
return ValueTask.FromResult<ChatHistory>([]);
|
return ValueTask.FromResult<ChatHistory>([]);
|
||||||
});
|
}, cancellationToken: cancellationToken);
|
||||||
|
|
||||||
var chat = new ChatHistory(historyCache);
|
var chat = new ChatHistory(historyCache);
|
||||||
return chat;
|
return chat;
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task SetChatHistoryAsync(Guid conversationId, string question, string answer)
|
private async Task SetChatHistoryAsync(Guid conversationId, string question, string answer, CancellationToken cancellationToken)
|
||||||
{
|
{
|
||||||
var history = await GetChatHistoryAsync(conversationId);
|
var history = await GetChatHistoryAsync(conversationId, cancellationToken);
|
||||||
|
|
||||||
history.AddUserMessage(question);
|
history.AddUserMessage(question);
|
||||||
history.AddAssistantMessage(answer);
|
history.AddAssistantMessage(answer);
|
||||||
|
|
||||||
await UpdateCacheAsync(conversationId, history);
|
await UpdateCacheAsync(conversationId, history, cancellationToken);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,33 +7,33 @@ namespace SqlDatabaseVectorSearch.Services;
|
|||||||
|
|
||||||
public class DocumentService(ApplicationDbContext dbContext)
|
public class DocumentService(ApplicationDbContext dbContext)
|
||||||
{
|
{
|
||||||
public async Task<IEnumerable<Document>> GetDocumentsAsync()
|
public async Task<IEnumerable<Document>> GetDocumentsAsync(CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
var documents = await dbContext.Documents.OrderBy(d => d.Name)
|
var documents = await dbContext.Documents.OrderBy(d => d.Name)
|
||||||
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count))
|
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count))
|
||||||
.ToListAsync();
|
.ToListAsync(cancellationToken: cancellationToken);
|
||||||
|
|
||||||
return documents;
|
return documents;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
|
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index)
|
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index)
|
||||||
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null))
|
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null))
|
||||||
.ToListAsync();
|
.ToListAsync(cancellationToken: cancellationToken);
|
||||||
|
|
||||||
return documentChunks;
|
return documentChunks;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
|
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId)
|
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId)
|
||||||
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding))
|
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding))
|
||||||
.FirstOrDefaultAsync();
|
.FirstOrDefaultAsync(cancellationToken: cancellationToken);
|
||||||
|
|
||||||
return documentChunk;
|
return documentChunk;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Task DeleteDocumentAsync(Guid documentId)
|
public Task DeleteDocumentAsync(Guid documentId, CancellationToken cancellationToken = default)
|
||||||
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
|
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync(cancellationToken: cancellationToken);
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,6 @@
|
|||||||
using System.Data;
|
using System.Data;
|
||||||
|
using System.Runtime.CompilerServices;
|
||||||
|
using System.Threading;
|
||||||
using Microsoft.EntityFrameworkCore;
|
using Microsoft.EntityFrameworkCore;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Microsoft.SemanticKernel.Embeddings;
|
using Microsoft.SemanticKernel.Embeddings;
|
||||||
@@ -14,24 +16,24 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
{
|
{
|
||||||
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
private readonly AppSettings appSettings = appSettingsOptions.Value;
|
||||||
|
|
||||||
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
|
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
// Extract the contents of the file.
|
// Extract the contents of the file.
|
||||||
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
|
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
|
||||||
var content = await decoder.DecodeAsync(stream, contentType);
|
var content = await decoder.DecodeAsync(stream, contentType, cancellationToken);
|
||||||
|
|
||||||
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
|
||||||
var tokenCount = tokenizerService.CountEmbeddingTokens(content);
|
var tokenCount = tokenizerService.CountEmbeddingTokens(content);
|
||||||
|
|
||||||
var strategy = dbContext.Database.CreateExecutionStrategy();
|
var strategy = dbContext.Database.CreateExecutionStrategy();
|
||||||
var document = await strategy.ExecuteAsync(async () =>
|
var document = await strategy.ExecuteAsync(async (cancellationToken) =>
|
||||||
{
|
{
|
||||||
await dbContext.Database.BeginTransactionAsync();
|
await dbContext.Database.BeginTransactionAsync();
|
||||||
|
|
||||||
if (documentId.HasValue)
|
if (documentId.HasValue)
|
||||||
{
|
{
|
||||||
// If the user is importing a document that already exists, delete the previous one.
|
// If the user is importing a document that already exists, delete the previous one.
|
||||||
await documentService.DeleteDocumentAsync(documentId.Value);
|
await documentService.DeleteDocumentAsync(documentId.Value, cancellationToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
|
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
|
||||||
@@ -39,7 +41,7 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
|
|
||||||
// Split the content into chunks and generate the embeddings for each one.
|
// Split the content into chunks and generate the embeddings for each one.
|
||||||
var paragraphs = textChunkerService.Split(content);
|
var paragraphs = textChunkerService.Split(content);
|
||||||
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs);
|
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs, cancellationToken: cancellationToken);
|
||||||
|
|
||||||
// Save the document chunks and the corresponding embedding in the database.
|
// Save the document chunks and the corresponding embedding in the database.
|
||||||
foreach (var (index, paragraph) in paragraphs.Index())
|
foreach (var (index, paragraph) in paragraphs.Index())
|
||||||
@@ -50,31 +52,31 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
dbContext.DocumentChunks.Add(documentChunk);
|
dbContext.DocumentChunks.Add(documentChunk);
|
||||||
}
|
}
|
||||||
|
|
||||||
await dbContext.SaveChangesAsync();
|
await dbContext.SaveChangesAsync(cancellationToken);
|
||||||
await dbContext.Database.CommitTransactionAsync();
|
await dbContext.Database.CommitTransactionAsync(cancellationToken);
|
||||||
|
|
||||||
return document;
|
return document;
|
||||||
});
|
}, cancellationToken);
|
||||||
|
|
||||||
return new(document.Id, tokenCount);
|
return new(document.Id, tokenCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true)
|
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true, CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
||||||
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
|
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
|
||||||
|
|
||||||
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken);
|
||||||
|
|
||||||
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
|
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
|
||||||
{
|
{
|
||||||
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
|
||||||
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
|
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
|
||||||
|
|
||||||
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
|
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken: cancellationToken);
|
||||||
|
|
||||||
// The first message contains the question and the corresponding token usage (if reformulated).
|
// The first message contains the question and the corresponding token usage (if reformulated).
|
||||||
yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
|
yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
|
||||||
@@ -96,20 +98,20 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
|
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate, CancellationToken cancellationToken)
|
||||||
{
|
{
|
||||||
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
|
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
|
||||||
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
|
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text, cancellationToken) : new(question.Text);
|
||||||
var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!);
|
var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!);
|
||||||
|
|
||||||
// Perform Vector Search on SQL Database.
|
// Perform Vector Search on SQL Database.
|
||||||
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
|
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!, cancellationToken: cancellationToken);
|
||||||
|
|
||||||
var chunks = await dbContext.DocumentChunks
|
var chunks = await dbContext.DocumentChunks
|
||||||
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
|
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
|
||||||
.Select(c => c.Content)
|
.Select(c => c.Content)
|
||||||
.Take(appSettings.MaxRelevantChunks)
|
.Take(appSettings.MaxRelevantChunks)
|
||||||
.ToListAsync();
|
.ToListAsync(cancellationToken);
|
||||||
|
|
||||||
return (reformulatedQuestion, embeddingTokenCount, chunks);
|
return (reformulatedQuestion, embeddingTokenCount, chunks);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user