Merge pull request #8 from marcominerva/token_usage

Token usage
This commit is contained in:
Marco Minerva
2025-02-07 12:04:13 +01:00
committed by GitHub
14 changed files with 441 additions and 152 deletions
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class ChatResponse(string? Text, TokenUsage? TokenUsage = null);
@@ -0,0 +1,3 @@
namespace SqlDatabaseVectorSearch.Models;
public record class ImportDocumentResponse(Guid DocumentId, int EmbeddingTokenCount);
@@ -0,0 +1,10 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Answer can be null when using response streaming.
public record class QuestionResponse(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null)
{
public QuestionResponse(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null)
: this(null, null, token, streamState, tokenUsageResponse)
{
}
}
@@ -1,4 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
// Question and Asnwer can be null when using response streaming.
public record class Response(string? Question, string? Answer, StreamState? StreamState = null);
@@ -0,0 +1,6 @@
namespace SqlDatabaseVectorSearch.Models;
public record class TokenUsage(int InputTokenCount, int OutputTokenCount)
{
public int TotalTokenCount => InputTokenCount + OutputTokenCount;
}
@@ -0,0 +1,9 @@
namespace SqlDatabaseVectorSearch.Models;
public record class TokenUsageResponse(TokenUsage? Reformulation, int? EmbeddingTokenCount, TokenUsage? Question)
{
public TokenUsageResponse(TokenUsage? question)
: this(null, null, question)
{
}
}
@@ -1,3 +0,0 @@
namespace SqlDatabaseVectorSearch.Models;
public record class UploadDocumentResponse(Guid DocumentId);
+61 -58
View File
@@ -19,9 +19,14 @@ builder.Configuration.AddJsonFile("appsettings.local.json", optional: true, relo
var aiSettings = builder.Services.ConfigureAndGet<AzureOpenAISettings>(builder.Configuration, "AzureOpenAI")!;
var appSettings = builder.Services.ConfigureAndGet<AppSettings>(builder.Configuration, nameof(AppSettings))!;
builder.Services.ConfigureHttpJsonOptions(options =>
{
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
});
builder.Services.AddSingleton(TimeProvider.System);
builder.Services.AddSqlServer<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options =>
builder.Services.AddAzureSql<ApplicationDbContext>(builder.Configuration.GetConnectionString("SqlConnection"), options =>
{
options.UseVectorSearch();
}, options =>
@@ -48,19 +53,17 @@ builder.Services.AddKernel()
.AddAzureOpenAITextEmbeddingGeneration(aiSettings.Embedding.Deployment, aiSettings.Embedding.Endpoint, aiSettings.Embedding.ApiKey, dimensions: aiSettings.Embedding.Dimensions)
.AddAzureOpenAIChatCompletion(aiSettings.ChatCompletion.Deployment, aiSettings.ChatCompletion.Endpoint, aiSettings.ChatCompletion.ApiKey);
builder.Services.AddSingleton<TextChunkerService>();
builder.Services.AddSingleton<TokenizerService>();
builder.Services.AddSingleton<ChatService>();
builder.Services.AddScoped<VectorSearchService>();
builder.Services.AddScoped<DocumentService>();
builder.Services.AddKeyedSingleton<IContentDecoder, PdfContentDecoder>(MediaTypeNames.Application.Pdf);
builder.Services.AddKeyedSingleton<IContentDecoder, DocxContentDecoder>("application/vnd.openxmlformats-officedocument.wordprocessingml.document");
builder.Services.AddKeyedSingleton<IContentDecoder, TextContentDecoder>(MediaTypeNames.Text.Plain);
builder.Services.ConfigureHttpJsonOptions(options =>
{
options.SerializerOptions.Converters.Add(new JsonStringEnumConverter());
});
builder.Services.AddOpenApi(options =>
{
options.RemoveServerList();
@@ -93,57 +96,6 @@ app.UseSwaggerUI(options =>
options.SwaggerEndpoint("/openapi/v1.json", builder.Environment.ApplicationName);
});
var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents");
documentsApiGroup.MapGet(string.Empty, async (VectorSearchService vectorSearchService) =>
{
var documents = await vectorSearchService.GetDocumentsAsync();
return TypedResults.Ok(documents);
})
.WithSummary("Gets the list of documents");
documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, VectorSearchService vectorSearchService) =>
{
var documents = await vectorSearchService.GetDocumentChunksAsync(documentId);
return TypedResults.Ok(documents);
})
.WithSummary("Gets the list of chunks of a given document")
.WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk.");
documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task<Results<Ok<DocumentChunk>, NotFound>> (Guid documentId, Guid documentChunkId, VectorSearchService vectorSearchService) =>
{
var chunk = await vectorSearchService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId);
if (chunk is null)
{
return TypedResults.NotFound();
}
return TypedResults.Ok(chunk);
})
.ProducesProblem(StatusCodes.Status404NotFound)
.WithSummary("Gets the details of a given chunk, includings its embedding");
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService,
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
{
using var stream = file.OpenReadStream();
documentId = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
return TypedResults.Ok(new UploadDocumentResponse(documentId.Value));
})
.DisableAntiforgery()
.ProducesProblem(StatusCodes.Status400BadRequest)
.WithSummary("Uploads a document")
.WithDescription("Uploads a document to SQL Database and saves its embedding using the native VECTOR type. The document will be indexed and used to answer questions. Currently, PDF, DOCX and TXT files are supported.");
documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, VectorSearchService vectorSearchService) =>
{
await vectorSearchService.DeleteDocumentAsync(documentId);
return TypedResults.NoContent();
})
.WithSummary("Deletes a document")
.WithDescription("This endpoint deletes the document and all its chunks.");
app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSearchService,
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
{
@@ -157,7 +109,7 @@ app.MapPost("/api/ask", async (Question question, VectorSearchService vectorSear
app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vectorSearchService,
[Description("If true, the question will be reformulated taking into account the context of the chat identified by the given ConversationId.")] bool reformulate = true) =>
{
async IAsyncEnumerable<Response> Stream()
async IAsyncEnumerable<QuestionResponse> Stream()
{
// Requests a streaming response.
var responseStream = vectorSearchService.AskStreamingAsync(question, reformulate);
@@ -174,4 +126,55 @@ app.MapPost("/api/ask-streaming", (Question question, VectorSearchService vector
.WithDescription("The question will be reformulated taking into account the context of the chat identified by the given ConversationId.")
.WithTags("Ask");
var documentsApiGroup = app.MapGroup("/api/documents").WithTags("Documents");
documentsApiGroup.MapGet(string.Empty, async (DocumentService documentService) =>
{
var documents = await documentService.GetDocumentsAsync();
return TypedResults.Ok(documents);
})
.WithSummary("Gets the list of documents");
documentsApiGroup.MapGet("{documentId:guid}/chunks", async (Guid documentId, DocumentService documentService) =>
{
var documents = await documentService.GetDocumentChunksAsync(documentId);
return TypedResults.Ok(documents);
})
.WithSummary("Gets the list of chunks of a given document")
.WithDescription("The list does not contain embedding. Use '/api/documents/{documentId}/chunks/{documentChunkId}' to get the embedding for a given chunk.");
documentsApiGroup.MapGet("{documentId:guid}/chunks/{documentChunkId:guid}", async Task<Results<Ok<DocumentChunk>, NotFound>> (Guid documentId, Guid documentChunkId, DocumentService documentService) =>
{
var chunk = await documentService.GetDocumentChunkEmbeddingAsync(documentId, documentChunkId);
if (chunk is null)
{
return TypedResults.NotFound();
}
return TypedResults.Ok(chunk);
})
.ProducesProblem(StatusCodes.Status404NotFound)
.WithSummary("Gets the details of a given chunk, includings its embedding");
documentsApiGroup.MapDelete("{documentId:guid}", async (Guid documentId, DocumentService documentService) =>
{
await documentService.DeleteDocumentAsync(documentId);
return TypedResults.NoContent();
})
.WithSummary("Deletes a document")
.WithDescription("This endpoint deletes the document and all its chunks.");
documentsApiGroup.MapPost(string.Empty, async (IFormFile file, VectorSearchService vectorSearchService,
[Description("The unique identifier of the document. If not provided, a new one will be generated. If you specify an existing documentId, the corresponding document will be overwritten.")] Guid? documentId = null) =>
{
using var stream = file.OpenReadStream();
var response = await vectorSearchService.ImportAsync(stream, file.FileName, file.ContentType, documentId);
return TypedResults.Ok(response);
})
.DisableAntiforgery()
.ProducesProblem(StatusCodes.Status400BadRequest)
.WithSummary("Uploads a document")
.WithDescription("Uploads a document to SQL Database and saves its embedding using the native VECTOR type. The document will be indexed and used to answer questions. Currently, PDF, DOCX and TXT files are supported.");
app.Run();
+46 -11
View File
@@ -3,6 +3,8 @@ using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using OpenAI.Chat;
using SqlDatabaseVectorSearch.Models;
using SqlDatabaseVectorSearch.Settings;
namespace SqlDatabaseVectorSearch.Services;
@@ -11,7 +13,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<string> CreateQuestionAsync(Guid conversationId, string question)
public async Task<ChatResponse> CreateQuestionAsync(Guid conversationId, string question)
{
var chat = await GetChatHistoryAsync(conversationId);
@@ -20,7 +22,7 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
---
{question}
---
You must reformulate the question in the same language of the user's question.
You must reformulate the question in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English.
Never add "in this chat", "in the context of this chat", "in the context of our conversation", "search for" or something like that in your answer.
""";
@@ -31,10 +33,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
await UpdateCacheAsync(conversationId, chat);
return reformulatedQuestion.Content!;
var tokenUsage = GetTokenUsage(reformulatedQuestion);
return new(reformulatedQuestion.Content!, tokenUsage);
}
public async Task<string> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
public async Task<ChatResponse> AskQuestionAsync(Guid conversationId, IEnumerable<string> chunks, string question)
{
var chat = CreateChatAsync(chunks, question);
@@ -46,10 +50,12 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
// Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.Content!);
return answer.Content!;
var tokenUsage = GetTokenUsage(answer);
return new(answer.Content!, tokenUsage);
}
public async IAsyncEnumerable<string> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
public async IAsyncEnumerable<ChatResponse> AskStreamingAsync(Guid conversationId, IEnumerable<string> chunks, string question)
{
var chat = CreateChatAsync(chunks, question);
@@ -61,22 +67,51 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
{
if (!string.IsNullOrEmpty(token.Content))
{
yield return token.Content;
yield return new(token.Content);
answer.Append(token.Content);
}
else if (token.Content is null)
{
// Token usage is returned in the last message, when the Content is null.
var tokenUsage = GetTokenUsage(token);
if (tokenUsage is not null)
{
yield return new(null, tokenUsage);
}
}
}
// Add question and answer to the chat history.
await SetChatHistoryAsync(conversationId, question, answer.ToString());
}
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.ChatMessageContent message)
{
if (message.InnerContent is ChatCompletion content && content.Usage is not null)
{
return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount);
}
return null;
}
private static TokenUsage? GetTokenUsage(Microsoft.SemanticKernel.StreamingChatMessageContent message)
{
if (message.InnerContent is StreamingChatCompletionUpdate content && content.Usage is not null)
{
return new(content.Usage.InputTokenCount, content.Usage.OutputTokenCount);
}
return null;
}
private ChatHistory CreateChatAsync(IEnumerable<string> chunks, string question)
{
var chat = new ChatHistory("""
You can use only the information provided in this chat to answer questions. If you don't know the answer, reply suggesting to refine the question.
For example, if the user asks "What is the capital of France?" and in this chat there isn't information about France, you should reply something like "This information isn't available in the given context".
Never answer to questions that are not related to this chat.
You must answer in the same language of the user's question.
You must answer in the same language of the user's question. For example, it the user asks a question in English, the answer must be in English.
""");
var prompt = new StringBuilder($"""
@@ -89,15 +124,15 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
""");
var tokensAvailable = appSettings.MaxInputTokens
- tokenizerService.CountTokens(chat[0].ToString()) // System prompt.
- tokenizerService.CountTokens(prompt.ToString()) // Initial user prompt.
- tokenizerService.CountChatCompletionTokens(chat[0].ToString()) // System prompt.
- tokenizerService.CountChatCompletionTokens(prompt.ToString()) // Initial user prompt.
- appSettings.MaxOutputTokens; // To ensure there is enough space for the answer.
foreach (var chunk in chunks)
{
var text = $"---{Environment.NewLine}{chunk}";
var tokenCount = tokenizerService.CountTokens(text);
var tokenCount = tokenizerService.CountChatCompletionTokens(text);
if (tokenCount > tokensAvailable)
{
// There isn't enough space to add the current chunk.
@@ -0,0 +1,39 @@
using System.Data;
using Microsoft.EntityFrameworkCore;
using SqlDatabaseVectorSearch.DataAccessLayer;
using SqlDatabaseVectorSearch.Models;
namespace SqlDatabaseVectorSearch.Services;
public class DocumentService(ApplicationDbContext dbContext)
{
public async Task<IEnumerable<Document>> GetDocumentsAsync()
{
var documents = await dbContext.Documents.OrderBy(d => d.Name)
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count))
.ToListAsync();
return documents;
}
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
{
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null))
.ToListAsync();
return documentChunks;
}
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
{
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding))
.FirstOrDefaultAsync();
return documentChunk;
}
public Task DeleteDocumentAsync(Guid documentId)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
}
@@ -0,0 +1,18 @@
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.Settings;
namespace SqlDatabaseVectorSearch.Services;
public class TextChunkerService(TokenizerService tokenizerService, IOptions<AppSettings> appSettingsOptions)
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
public IList<string> Split(string text)
{
var lines = TextChunker.SplitPlainTextLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountChatCompletionTokens);
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountChatCompletionTokens);
return paragraphs;
}
}
@@ -6,8 +6,13 @@ namespace SqlDatabaseVectorSearch.Services;
public class TokenizerService(IOptions<AzureOpenAISettings> settingsOptions)
{
private readonly TiktokenTokenizer tokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId);
private readonly TiktokenTokenizer chatCompletiontokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.ChatCompletion.ModelId);
public int CountTokens(string input)
=> tokenizer.CountTokens(input);
private readonly TiktokenTokenizer embeddingTokenizer = TiktokenTokenizer.CreateForModel(settingsOptions.Value.Embedding.ModelId);
public int CountChatCompletionTokens(string input)
=> chatCompletiontokenizer.CountTokens(input);
public int CountEmbeddingTokens(string input)
=> embeddingTokenizer.CountTokens(input);
}
@@ -2,7 +2,6 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Embeddings;
using Microsoft.SemanticKernel.Text;
using SqlDatabaseVectorSearch.ContentDecoders;
using SqlDatabaseVectorSearch.DataAccessLayer;
using SqlDatabaseVectorSearch.Models;
@@ -11,111 +10,100 @@ using Entities = SqlDatabaseVectorSearch.DataAccessLayer.Entities;
namespace SqlDatabaseVectorSearch.Services;
public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, ITextEmbeddingGenerationService textEmbeddingGenerationService, ChatService chatService, TokenizerService tokenizerService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDbContext dbContext, DocumentService documentService, ITextEmbeddingGenerationService textEmbeddingGenerationService, TokenizerService tokenizerService, TextChunkerService textChunkerService, ChatService chatService, TimeProvider timeProvider, IOptions<AppSettings> appSettingsOptions, ILogger<VectorSearchService> logger)
{
private readonly AppSettings appSettings = appSettingsOptions.Value;
public async Task<Guid> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
public async Task<ImportDocumentResponse> ImportAsync(Stream stream, string name, string contentType, Guid? documentId)
{
// Extract the contents of the file.
var decoder = serviceProvider.GetKeyedService<IContentDecoder>(contentType) ?? throw new NotSupportedException($"Content type '{contentType}' is not supported.");
var content = await decoder.DecodeAsync(stream, contentType);
await dbContext.Database.BeginTransactionAsync();
// We get the token count of the whole document because it is the total number of token used by embedding (it may be necessary, for example, for cost analysis).
var tokenCount = tokenizerService.CountEmbeddingTokens(content);
if (documentId.HasValue)
var strategy = dbContext.Database.CreateExecutionStrategy();
var document = await strategy.ExecuteAsync(async () =>
{
// If the user is importing a document that already exists, delete the previous one.
await DeleteDocumentAsync(documentId.Value);
}
await dbContext.Database.BeginTransactionAsync();
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
dbContext.Documents.Add(document);
if (documentId.HasValue)
{
// If the user is importing a document that already exists, delete the previous one.
await documentService.DeleteDocumentAsync(documentId.Value);
}
// Split the content into chunks and generate the embeddings for each one.
var lines = TextChunker.SplitPlainTextLines(content, appSettings.MaxTokensPerLine, tokenizerService.CountTokens);
var paragraphs = TextChunker.SplitPlainTextParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountTokens);
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs);
var document = new Entities.Document { Id = documentId.GetValueOrDefault(), Name = name, CreationDate = timeProvider.GetUtcNow() };
dbContext.Documents.Add(document);
// Save the document chunks and the corresponding embedding in the database.
foreach (var (index, paragraph) in paragraphs.Index())
{
logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountTokens(paragraph));
// Split the content into chunks and generate the embeddings for each one.
var paragraphs = textChunkerService.Split(content);
var embeddings = await textEmbeddingGenerationService.GenerateEmbeddingsAsync(paragraphs);
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() };
dbContext.DocumentChunks.Add(documentChunk);
}
// Save the document chunks and the corresponding embedding in the database.
foreach (var (index, paragraph) in paragraphs.Index())
{
logger.LogInformation("Storing a paragraph of {TokenCount} tokens.", tokenizerService.CountChatCompletionTokens(paragraph));
await dbContext.SaveChangesAsync();
await dbContext.Database.CommitTransactionAsync();
var documentChunk = new Entities.DocumentChunk { Document = document, Index = index, Content = paragraph!, Embedding = embeddings[index].ToArray() };
dbContext.DocumentChunks.Add(documentChunk);
}
return document.Id;
await dbContext.SaveChangesAsync();
await dbContext.Database.CommitTransactionAsync();
return document;
});
return new(document.Id, tokenCount);
}
public async Task<IEnumerable<Document>> GetDocumentsAsync()
public async Task<QuestionResponse> AskQuestionAsync(Question question, bool reformulate = true)
{
var documents = await dbContext.Documents.OrderBy(d => d.Name)
.Select(d => new Document(d.Id, d.Name, d.CreationDate, d.Chunks.Count))
.ToListAsync();
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
return documents;
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage));
}
public async Task<IEnumerable<DocumentChunk>> GetDocumentChunksAsync(Guid documentId)
public async IAsyncEnumerable<QuestionResponse> AskStreamingAsync(Question question, bool reformulate = true)
{
var documentChunks = await dbContext.DocumentChunks.Where(c => c.DocumentId == documentId).OrderBy(c => c.Index)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, null))
.ToListAsync();
// It the user doesn't want to reforulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate);
return documentChunks;
}
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion.Text!);
public async Task<DocumentChunk?> GetDocumentChunkEmbeddingAsync(Guid documentId, Guid documentChunkId)
{
var documentChunk = await dbContext.DocumentChunks.Where(c => c.Id == documentChunkId && c.DocumentId == documentId)
.Select(c => new DocumentChunk(c.Id, c.Index, c.Content, c.Embedding))
.FirstOrDefaultAsync();
// The first message contains the question and the corresponding token usage (if reformulated).
yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
return documentChunk;
}
public Task DeleteDocumentAsync(Guid documentId)
=> dbContext.Documents.Where(d => d.Id == documentId).ExecuteDeleteAsync();
public async Task<Response> AskQuestionAsync(Question question, bool reformulate = true)
{
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
var answer = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion);
return new Response(reformulatedQuestion, answer);
}
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true)
{
var (reformulatedQuestion, chunks) = await CreateContextAsync(question, reformulate);
var answerStream = chatService.AskStreamingAsync(question.ConversationId, chunks, reformulatedQuestion);
// The first message contains the original question.
yield return new Response(reformulatedQuestion, null, StreamState.Start);
TokenUsageResponse? tokenUsageResponse = null;
// Return each token as a partial response.
await foreach (var token in answerStream)
await foreach (var (token, tokenUsage) in answerStream)
{
yield return new Response(null, token, StreamState.Append);
// Token usage is expected in the last message.
tokenUsageResponse = tokenUsage is not null ? new(tokenUsage) : null;
yield return new(token, tokenUsageResponse is null ? StreamState.Append : StreamState.End, tokenUsageResponse);
}
// The last message tells the client that the stream has ended.
yield return new Response(null, null, StreamState.End);
// If the token usage has not been returned in the last message, we must explicitly tells that the stream is ended.
if (tokenUsageResponse is null)
{
yield return new(null, StreamState.End);
}
}
private async Task<(string Question, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
private async Task<(ChatResponse ReformulatedQuestion, int EmbeddingTokenCount, IEnumerable<string> Chunks)> CreateContextAsync(Question question, bool reformulate = true)
{
// Reformulate the following question taking into account the context of the chat to perform keyword search and embeddings:
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : question.Text;
// Reformulate the question taking into account the context of the chat to perform keyword search and embeddings.
var reformulatedQuestion = reformulate ? await chatService.CreateQuestionAsync(question.ConversationId, question.Text) : new(question.Text);
var embeddingTokenCount = tokenizerService.CountEmbeddingTokens(reformulatedQuestion.Text!);
// Perform Vector Search on SQL Database.
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion);
var questionEmbedding = await textEmbeddingGenerationService.GenerateEmbeddingAsync(reformulatedQuestion.Text!);
var chunks = await dbContext.DocumentChunks
.OrderBy(c => EF.Functions.VectorDistance("cosine", c.Embedding, questionEmbedding.ToArray()))
@@ -123,6 +111,6 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
.Take(appSettings.MaxRelevantChunks)
.ToListAsync();
return (reformulatedQuestion, chunks);
return (reformulatedQuestion, embeddingTokenCount, chunks);
}
}