Enhance citation handling and formatting

Updated citation management in the application by removing the `RemoveCitations` and `ExtractCitations` methods in `Ask.razor`, and directly processing citations from the `delta` object. The `Response` class now includes a `Citations` property for better data handling.

Modified `VectorSearchService.cs` to extract citations from the full answer in `AskQuestionAsync` and return them at the end of the streaming process in `AskStreamingAsync`.

Introduced a new `Citation` class in `Citation.cs` to encapsulate citation properties, ensuring structured management of citation data.

Updated citation formatting rules to enforce a specific XML format, ensuring citations are presented at the end of responses rather than within the answer text.
This commit is contained in:
Marco Minerva
2025-06-10 11:50:51 +02:00
parent cdf8356e11
commit 4571478787
5 changed files with 120 additions and 72 deletions
@@ -197,16 +197,20 @@
else if (delta.StreamState == StreamState.Append) else if (delta.StreamState == StreamState.Append)
{ {
// Adds tokens to the assistant message as they are received. // Adds tokens to the assistant message as they are received.
assistantMessage.RawText += delta.Answer; assistantMessage.Text += delta.Answer;
// Updates the Text property to remove citations, if any.
assistantMessage.Text = RemoveCitations(assistantMessage.RawText);
} }
else if (delta.StreamState == StreamState.End) else if (delta.StreamState == StreamState.End)
{ {
// Extracts citations, if any. // Get citations from the response.
var (_, citations) = ExtractCitations(assistantMessage.RawText); assistantMessage.Citations = delta.Citations?.Select(c => new Citation
assistantMessage.Citations = citations; {
DocumentId = c.DocumentId,
ChunkId = c.ChunkId,
FileName = c.FileName,
Quote = c.Quote,
PageNumber = c.PageNumber,
IndexOnPage = c.IndexOnPage
});
assistantMessage.IsCompleted = true; assistantMessage.IsCompleted = true;
assistantMessage.TokenUsage += FormatTokenUsage(delta.TokenUsage); assistantMessage.TokenUsage += FormatTokenUsage(delta.TokenUsage);
@@ -295,58 +299,8 @@
await JSRuntime.InvokeVoidAsync("scrollTo", chat); await JSRuntime.InvokeVoidAsync("scrollTo", chat);
} }
/// <summary>
/// Returns the part of <paramref name="text"/> that precedes the first "&lt;cit" marker
/// (matched case-insensitively), with trailing whitespace removed.
/// </summary>
/// <param name="text">The raw text received so far; may be null or empty.</param>
/// <returns>
/// An empty string for null/empty input; otherwise the text up to the marker, or the whole
/// text when no marker is present, trimmed at the end.
/// </returns>
private static string RemoveCitations(string? text)
{
    if (text is null || text.Length == 0)
    {
        return string.Empty;
    }

    // Everything from the first citation marker onwards is hidden from the displayed text.
    var markerStart = text.IndexOf("<cit", StringComparison.OrdinalIgnoreCase);
    var visibleText = markerStart < 0 ? text : text.Substring(0, markerStart);

    return visibleText.TrimEnd();
}
/// <summary>
/// Extracts the &lt;citation&gt; elements embedded in <paramref name="text"/> and returns the text without them.
/// </summary>
/// <param name="text">The raw assistant text; may be null or empty.</param>
/// <returns>
/// The text with every matched citation element removed (trailing whitespace trimmed) and the
/// citations that could be parsed. Citations whose document or chunk id is not a valid GUID are skipped.
/// </returns>
private static (string, IEnumerable<Citation>) ExtractCitations(string? text)
{
    var citations = new List<Citation>();
    if (string.IsNullOrEmpty(text))
    {
        return (text ?? string.Empty, citations);
    }

    var pattern = @"<citation\s+document-id='(?<documentId>[^']*)'\s+chunk-id='(?<chunkId>[^']*)'\s+filename='(?<filename>[^']*)'\s+page-number='(?<pageNumber>[^']*)'\s+index-on-page='(?<indexOnPage>[^']*)'>\s*(?<quote>.*?)\s*</citation>";

    foreach (Match match in Regex.Matches(text, pattern, RegexOptions.Singleline))
    {
        // The markup is produced by the model, so the ids are not guaranteed to be valid GUIDs.
        // Skip a malformed citation instead of letting Guid.Parse throw and lose the whole answer.
        if (!Guid.TryParse(match.Groups["documentId"].Value, out var documentId)
            || !Guid.TryParse(match.Groups["chunkId"].Value, out var chunkId))
        {
            continue;
        }

        citations.Add(new Citation
        {
            DocumentId = documentId,
            ChunkId = chunkId,
            FileName = match.Groups["filename"].Value,
            // Page numbers that are missing, non-numeric or not positive are reported as null.
            PageNumber = int.TryParse(match.Groups["pageNumber"].Value, out var pageNumber) && pageNumber > 0 ? pageNumber : null,
            IndexOnPage = int.TryParse(match.Groups["indexOnPage"].Value, out var indexOnPage) ? indexOnPage : 0,
            Quote = match.Groups["quote"].Value
        });
    }

    // Remove all <citation> tags from the text (including the malformed ones skipped above).
    var cleanText = Regex.Replace(text, pattern, string.Empty, RegexOptions.Singleline).TrimEnd();

    return (cleanText, citations);
}
public class Message public class Message
{ {
private string? rawText;
public string? RawText
{
get => rawText ?? Text;
set => rawText = value;
}
public string? Text { get; set; } public string? Text { get; set; }
public required string Role { get; set; } public required string Role { get; set; }
@@ -0,0 +1,16 @@
namespace SqlDatabaseVectorSearch.Models;
/// <summary>
/// Describes a single citation attached to an answer: the document chunk it comes from and the quoted text.
/// </summary>
public class Citation
{
/// <summary>Identifier of the cited document.</summary>
public Guid DocumentId { get; set; }
/// <summary>Identifier of the cited chunk.</summary>
public Guid ChunkId { get; set; }
/// <summary>File name of the cited document.</summary>
// The null-forgiving initializer only silences the nullable warning; the creator is expected to assign a value.
public string FileName { get; set; } = null!;
/// <summary>Text quoted from the source.</summary>
// Same as FileName: no real default, the value must be assigned by the creator.
public string Quote { get; set; } = null!;
/// <summary>Page number of the citation, or null when no page number is available.</summary>
public int? PageNumber { get; set; }
/// <summary>Index of the citation on its page (the "index-on-page" value of the citation markup).</summary>
public int IndexOnPage { get; set; }
}
+3 -3
View File
@@ -1,10 +1,10 @@
namespace SqlDatabaseVectorSearch.Models; namespace SqlDatabaseVectorSearch.Models;
// Question and Answer can be null when using response streaming. // Question and Answer can be null when using response streaming.
public record class Response(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null) public record class Response(string? OriginalQuestion, string? ReformulatedQuestion, string? Answer, StreamState? StreamState = null, TokenUsageResponse? TokenUsage = null, IEnumerable<Citation>? Citations = null)
{ {
public Response(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null) public Response(string? token, StreamState streamState, TokenUsageResponse? tokenUsageResponse = null, IEnumerable<Citation>? citations = null)
: this(null, null, token, streamState, tokenUsageResponse) : this(null, null, token, streamState, tokenUsageResponse, citations)
{ {
} }
} }
@@ -125,14 +125,33 @@ public class ChatService(IChatCompletionService chatCompletionService, Tokenizer
- I'm sorry, I don't have enough information to answer that question - I'm sorry, I don't have enough information to answer that question
Never answer questions that are not related to this chat. Never answer questions that are not related to this chat.
You must answer in the same language as the user's question. You must answer in the same language as the user's question. For example, if the user asks a question in English, the answer must be in English, no matter the language of the documents.
The quote in each <citation> MUST be MAXIMUM 5 words, taken word-for-word from the search result. If the quote is longer than 5 words, your answer is INVALID. After the answer, you need to include citations following the XML format below:
When you find an answer, you MUST place ALL citations ONLY at the very end of your response, never inside or between sentences. <citation document-id='document_id' chunk-id='chunk_id' filename='string' page-number='page_number' index-on-page='index_on-page'>exact quote here</citation>
First provide your complete answer, then list all citations. <citation document-id='document_id' chunk-id='chunk_id' filename='string' page-number='page_number' index-on-page='index_on-page'>exact quote here</citation>
Use this XML format for citations: The entire list of XML citations MUST be enclosed between 【 and 】 (U+3010 and U+3011) and must exactly match the above format.
<citation document-id='document_id' chunk-id='chunk_id' filename='string' page-number='page_number' index-on-page='index_on_page'>exact quote here</citation> The quote in each <citation> MUST be MAXIMUM 5 words, taken word-for-word from the search result.
IMPORTANT CITATION RULES:
1. NEVER put citations inside your answer text.
2. ALWAYS provide your complete answer FIRST.
3. ONLY AFTER completing your answer, add ALL citations in a block at the very end.
4. The citations block MUST be the last thing in your response.
5. NEVER reference citations by number or mention them in your answer text.
6. The citations MUST ALWAYS follow the XML format exactly as shown below. Any other format is NOT ACCEPTED.
---
Example of a correct answer:
The capital of France is Paris.
<citation document-id='123' chunk-id='456' filename='france.pdf' page-number='1' index-on-page='1'>capital of France is Paris</citation>
Example of an incorrect answer (NOT ACCEPTED):
The capital of France is Paris [1].
[1] france.pdf, page 1
---
Only the correct format is accepted. If you do not follow the XML format exactly, your answer will be considered invalid.
"""); """);
var prompt = new StringBuilder($""" var prompt = new StringBuilder($"""
@@ -1,5 +1,7 @@
using System.Data; using System.Data;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Text;
using System.Text.RegularExpressions;
using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.AI; using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
@@ -74,9 +76,12 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
// If the user doesn't want to reformulate the question, CreateContextAsync returns the original one. // If the user doesn't want to reformulate the question, CreateContextAsync returns the original one.
var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken); var (reformulatedQuestion, embeddingTokenCount, chunks) = await CreateContextAsync(question, reformulate, cancellationToken);
var (answer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken); var (fullAnswer, tokenUsage) = await chatService.AskQuestionAsync(question.ConversationId, chunks, reformulatedQuestion.Text!, cancellationToken);
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage)); // Extract citations from the answer
var (answer, citations) = ExtractCitations(fullAnswer);
return new(question.Text, reformulatedQuestion.Text!, answer, null, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, tokenUsage), citations);
} }
public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default) public async IAsyncEnumerable<Response> AskStreamingAsync(Question question, bool reformulate = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
@@ -90,19 +95,42 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null)); yield return new(question.Text, reformulatedQuestion.Text!, null, StreamState.Start, new(reformulatedQuestion.TokenUsage, embeddingTokenCount, null));
TokenUsageResponse? tokenUsageResponse = null; TokenUsageResponse? tokenUsageResponse = null;
var fullAnswer = new StringBuilder();
var areCitationsStarted = false;
// Return each token as a partial response. // Return each token as a partial response.
await foreach (var (token, tokenUsage) in answerStream) await foreach (var (token, tokenUsage) in answerStream)
{ {
// Token usage is expected in the last message. fullAnswer.Append(token);
tokenUsageResponse = tokenUsage is not null ? new(tokenUsage) : null;
yield return new(token, tokenUsageResponse is null ? StreamState.Append : StreamState.End, tokenUsageResponse); if (token?.Contains('【') == true)
{
// Citations are started when the first token contains a 【 character.
// We need to track it because we don't want to return the citations in the actual response.
areCitationsStarted = true;
} }
// If the token usage has not been returned in the last message, we must explicitly tells that the stream is ended. if (!areCitationsStarted)
{
yield return new(token, StreamState.Append);
}
// Token usage is expected in the last message.
tokenUsageResponse = tokenUsage is not null ? new(tokenUsage) : null;
if (tokenUsageResponse is not null)
{
// Response is complete, we can return the citations.
var (_, citations) = ExtractCitations(fullAnswer.ToString());
yield return new(null, StreamState.End, tokenUsageResponse, citations);
}
}
// If the token usage has not been returned in the last message, we must explicitly tell that the stream is ended.
if (tokenUsageResponse is null) if (tokenUsageResponse is null)
{ {
yield return new(null, StreamState.End); // Extract citations at the end of streaming.
var (_, citations) = ExtractCitations(fullAnswer.ToString());
yield return new(null, StreamState.End, null, citations);
} }
} }
@@ -124,4 +152,35 @@ public class VectorSearchService(IServiceProvider serviceProvider, ApplicationDb
return (reformulatedQuestion, embeddingTokenCount, chunks); return (reformulatedQuestion, embeddingTokenCount, chunks);
} }
/// <summary>
/// Splits a raw model answer into the user-facing text and the citations embedded in it.
/// </summary>
/// <param name="text">The full answer produced by the model; may be null or empty.</param>
/// <returns>
/// The answer with the citation markup removed (trailing whitespace trimmed) and the citations that
/// could be parsed. Citations whose document or chunk id is not a valid GUID are skipped.
/// </returns>
private static (string, IEnumerable<Citation>) ExtractCitations(string? text)
{
    var citations = new List<Citation>();
    if (string.IsNullOrEmpty(text))
    {
        return (text ?? string.Empty, citations);
    }

    const string citationPattern = @"<citation\s+document-id='(?<documentId>[^']*)'\s+chunk-id='(?<chunkId>[^']*)'\s+filename='(?<filename>[^']*)'\s+page-number='(?<pageNumber>[^']*)'\s+index-on-page='(?<indexOnPage>[^']*)'>\s*(?<quote>.*?)\s*</citation>";

    foreach (Match match in Regex.Matches(text, citationPattern, RegexOptions.Singleline))
    {
        // The markup is generated by the model, so the ids are not guaranteed to be valid GUIDs.
        // Skip a malformed citation instead of letting Guid.Parse throw and fail the whole response.
        if (!Guid.TryParse(match.Groups["documentId"].Value, out var documentId)
            || !Guid.TryParse(match.Groups["chunkId"].Value, out var chunkId))
        {
            continue;
        }

        citations.Add(new Citation
        {
            DocumentId = documentId,
            ChunkId = chunkId,
            FileName = match.Groups["filename"].Value,
            // Page numbers that are missing, non-numeric or not positive are reported as null.
            PageNumber = int.TryParse(match.Groups["pageNumber"].Value, out var pageNumber) && pageNumber > 0 ? pageNumber : null,
            IndexOnPage = int.TryParse(match.Groups["indexOnPage"].Value, out var indexOnPage) ? indexOnPage : 0,
            Quote = match.Groups["quote"].Value
        });
    }

    // Remove all content between 【 and 】.
    var cleanText = Regex.Replace(text, @"【.*?】", string.Empty, RegexOptions.Singleline);

    // The model does not always wrap the citations in the brackets: also strip any citation tag
    // left in the text so that the markup never reaches the user.
    cleanText = Regex.Replace(cleanText, citationPattern, string.Empty, RegexOptions.Singleline).TrimEnd();

    return (cleanText, citations);
}
} }