diff --git a/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs b/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs index be81fd9..80c25e0 100644 --- a/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs +++ b/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs @@ -44,7 +44,6 @@ internal static class PlainTextChunker /// The number of tokens in the input string. public delegate int TokenCounter(string input); - private static readonly char[] spaceChar = [' ']; private static readonly string?[] plainTextSplitOptions = ["\n", ".。.", "?!", ";", ":", ",,、", ")]}", " ", "-", null]; private static readonly string?[] markdownSplitOptions = [".\u3002\uFF0E", "?!", ";", ":", ",\uFF0C\u3001", ")]}", " ", "-", "\n\r", null]; @@ -55,8 +54,13 @@ internal static class PlainTextChunker /// Maximum number of tokens per line. /// Function to count tokens in a string. If not supplied, the default counter will be used. /// List of lines. - public static List SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null) => - InternalSplitLines(text, maxTokensPerLine, trim: true, plainTextSplitOptions, tokenCounter); + public static List SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null) + { + ArgumentNullException.ThrowIfNull(text); + ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine)); + + return InternalSplitLines(text, maxTokensPerLine, trim: true, plainTextSplitOptions, tokenCounter); + } /// /// Split markdown text into lines. @@ -65,8 +69,13 @@ internal static class PlainTextChunker /// Maximum number of tokens per line. /// Function to count tokens in a string. If not supplied, the default counter will be used. /// List of lines. - public static List SplitMarkDownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null) => - InternalSplitLines(text, maxTokensPerLine, trim: true, markdownSplitOptions, tokenCounter); + public static List SplitMarkdownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null) + { + ArgumentNullException.ThrowIfNull(text); + ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine)); + + return InternalSplitLines(text, maxTokensPerLine, trim: true, markdownSplitOptions, tokenCounter); + } /// /// Split plain text into paragraphs. @@ -78,7 +87,7 @@ internal static class PlainTextChunker /// Function to count tokens in a string. If not supplied, the default counter will be used. /// List of paragraphs. public static List SplitPlainTextParagraphs(IEnumerable lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null) - => InternalSplitTextParagraphs(lines.Select(line => line.Replace("\r\n", "\n").Replace('\r', '\n')), maxTokensPerParagraph, overlapTokens, chunkHeader, + => InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader, static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, plainTextSplitOptions, tokenCounter), tokenCounter); /// @@ -96,17 +105,20 @@ internal static class PlainTextChunker private static List InternalSplitTextParagraphs(IEnumerable lines, int maxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func> longLinesSplitter, TokenCounter? tokenCounter) { - if (maxTokensPerParagraph <= 0) + ArgumentNullException.ThrowIfNull(lines); + ValidateMaxTokens(maxTokensPerParagraph, nameof(maxTokensPerParagraph)); + + if (overlapTokens < 0) { - throw new ArgumentException("maxTokensPerParagraph should be a positive number", nameof(maxTokensPerParagraph)); + throw new ArgumentOutOfRangeException(nameof(overlapTokens), "overlapTokens cannot be negative."); } if (maxTokensPerParagraph <= overlapTokens) { - throw new ArgumentException("overlapTokens cannot be larger than maxTokensPerParagraph", nameof(maxTokensPerParagraph)); + throw new ArgumentException("overlapTokens cannot be larger than or equal to maxTokensPerParagraph.", nameof(overlapTokens)); } - // Optimize empty inputs if we can efficiently determine the're empty + // Optimize empty inputs if we can efficiently determine they're empty. if (lines is ICollection c && c.Count == 0) { return []; @@ -114,9 +126,13 @@ internal static class PlainTextChunker var chunkHeaderTokens = chunkHeader is { Length: > 0 } ? GetTokenCount(chunkHeader, tokenCounter) : 0; var adjustedMaxTokensPerParagraph = maxTokensPerParagraph - overlapTokens - chunkHeaderTokens; + if (adjustedMaxTokensPerParagraph <= 0) + { + throw new ArgumentException("chunkHeader and overlapTokens must leave room for paragraph content.", nameof(chunkHeader)); + } // Split long lines first - var truncatedLines = lines.SelectMany(line => longLinesSplitter(line, adjustedMaxTokensPerParagraph, tokenCounter)); + var truncatedLines = lines.SelectMany(line => longLinesSplitter(NormalizeLineEndings(line), adjustedMaxTokensPerParagraph, tokenCounter)); var paragraphs = BuildParagraph(truncatedLines, adjustedMaxTokensPerParagraph, tokenCounter); var processedParagraphs = ProcessParagraphs(paragraphs, adjustedMaxTokensPerParagraph, overlapTokens, chunkHeader, longLinesSplitter, tokenCounter); @@ -174,18 +190,11 @@ internal static class PlainTextChunker if (GetTokenCount(lastParagraph, tokenCounter) < adjustedMaxTokensPerParagraph / 4) { - var lastParagraphTokens = lastParagraph.Split(spaceChar, StringSplitOptions.RemoveEmptyEntries); - var secondLastParagraphTokens = secondLastParagraph.Split(spaceChar, StringSplitOptions.RemoveEmptyEntries); + var mergedParagraph = $"{secondLastParagraph} {lastParagraph}"; - var lastParagraphTokensCount = lastParagraphTokens.Length; - var secondLastParagraphTokensCount = secondLastParagraphTokens.Length; - - if (lastParagraphTokensCount + secondLastParagraphTokensCount <= adjustedMaxTokensPerParagraph) + if (GetTokenCount(mergedParagraph, tokenCounter) <= adjustedMaxTokensPerParagraph) { - var newSecondLastParagraph = string.Join(" ", secondLastParagraphTokens); - var newLastParagraph = string.Join(" ", lastParagraphTokens); - - paragraphs[^2] = $"{newSecondLastParagraph} {newLastParagraph}"; + paragraphs[^2] = mergedParagraph; paragraphs.RemoveAt(paragraphs.Count - 1); } } @@ -232,7 +241,7 @@ internal static class PlainTextChunker { var result = new StringListWithTokenCount(tokenCounter); - text = text.Replace("\r\n", "\n"); // normalize line endings + text = NormalizeLineEndings(text); result.Add(text); for (var i = 0; i < splitOptions.Length; i++) { @@ -314,9 +323,9 @@ internal static class PlainTextChunker } // Recursion - var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(firstHalf.ToString(), tokenCounter)); + var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(firstHalf, tokenCounter)); result.AddRange(splits1); - var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(secondHalf.ToString(), tokenCounter)); + var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(secondHalf, tokenCounter)); result.AddRange(splits2); inputWasSplit = split1 || split2; @@ -326,10 +335,14 @@ internal static class PlainTextChunker var resultString = inputString ?? input.ToString(); var resultTokenCount = inputTokenCount; - if (trim && !resultString.Trim().Equals(resultString, StringComparison.Ordinal)) + if (trim) { - resultString = resultString.Trim(); - resultTokenCount = GetTokenCount(resultString, tokenCounter); + var trimmedResult = resultString.Trim(); + if (!trimmedResult.Equals(resultString, StringComparison.Ordinal)) + { + resultString = trimmedResult; + resultTokenCount = GetTokenCount(resultString, tokenCounter); + } } result.Add(resultString, resultTokenCount); @@ -339,9 +352,21 @@ internal static class PlainTextChunker private static int GetTokenCount(string input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input); + private static int GetTokenCount(ReadOnlySpan input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input.ToString()); + + private static string NormalizeLineEndings(string text) => text.Replace("\r\n", "\n").Replace('\r', '\n'); + + private static void ValidateMaxTokens(int maxTokens, string parameterName) + { + if (maxTokens <= 0) + { + throw new ArgumentOutOfRangeException(parameterName, "The maximum token count must be a positive number."); + } + } + private static int GetDefaultTokenCount(int length) { Debug.Assert(length >= 0); - return length >> 2; + return length == 0 ? 0 : Math.Max(1, length >> 2); } } \ No newline at end of file diff --git a/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs b/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs index 37a06e7..ebea231 100644 --- a/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs +++ b/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs @@ -11,7 +11,7 @@ public class MarkdownTextChunker(TokenizerService tokenizerService, IOptions Split(string text) { - var lines = PlainTextChunker.SplitMarkDownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens); + var lines = PlainTextChunker.SplitMarkdownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens); var paragraphs = PlainTextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens); return paragraphs;