diff --git a/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs b/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs
index be81fd9..80c25e0 100644
--- a/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs
+++ b/SqlDatabaseVectorSearch/TextChunkers/Implementations/PlainTextChunker.cs
@@ -44,7 +44,6 @@ internal static class PlainTextChunker
/// The number of tokens in the input string.
public delegate int TokenCounter(string input);
- private static readonly char[] spaceChar = [' '];
private static readonly string?[] plainTextSplitOptions = ["\n", ".。.", "?!", ";", ":", ",,、", ")]}", " ", "-", null];
private static readonly string?[] markdownSplitOptions = [".\u3002\uFF0E", "?!", ";", ":", ",\uFF0C\u3001", ")]}", " ", "-", "\n\r", null];
@@ -55,8 +54,13 @@ internal static class PlainTextChunker
/// Maximum number of tokens per line.
/// Function to count tokens in a string. If not supplied, the default counter will be used.
/// List of lines.
- public static List SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null) =>
- InternalSplitLines(text, maxTokensPerLine, trim: true, plainTextSplitOptions, tokenCounter);
+ public static List SplitPlainTextLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
+ {
+ ArgumentNullException.ThrowIfNull(text);
+ ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));
+
+ return InternalSplitLines(text, maxTokensPerLine, trim: true, plainTextSplitOptions, tokenCounter);
+ }
///
/// Split markdown text into lines.
@@ -65,8 +69,13 @@ internal static class PlainTextChunker
/// Maximum number of tokens per line.
/// Function to count tokens in a string. If not supplied, the default counter will be used.
/// List of lines.
- public static List SplitMarkDownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null) =>
- InternalSplitLines(text, maxTokensPerLine, trim: true, markdownSplitOptions, tokenCounter);
+ public static List SplitMarkdownLines(string text, int maxTokensPerLine, TokenCounter? tokenCounter = null)
+ {
+ ArgumentNullException.ThrowIfNull(text);
+ ValidateMaxTokens(maxTokensPerLine, nameof(maxTokensPerLine));
+
+ return InternalSplitLines(text, maxTokensPerLine, trim: true, markdownSplitOptions, tokenCounter);
+ }
///
/// Split plain text into paragraphs.
@@ -78,7 +87,7 @@ internal static class PlainTextChunker
/// Function to count tokens in a string. If not supplied, the default counter will be used.
/// List of paragraphs.
public static List SplitPlainTextParagraphs(IEnumerable lines, int maxTokensPerParagraph, int overlapTokens = 0, string? chunkHeader = null, TokenCounter? tokenCounter = null)
- => InternalSplitTextParagraphs(lines.Select(line => line.Replace("\r\n", "\n").Replace('\r', '\n')), maxTokensPerParagraph, overlapTokens, chunkHeader,
+ => InternalSplitTextParagraphs(lines, maxTokensPerParagraph, overlapTokens, chunkHeader,
static (text, maxTokens, tokenCounter) => InternalSplitLines(text, maxTokens, trim: false, plainTextSplitOptions, tokenCounter), tokenCounter);
///
@@ -96,17 +105,20 @@ internal static class PlainTextChunker
private static List InternalSplitTextParagraphs(IEnumerable lines, int maxTokensPerParagraph, int overlapTokens, string? chunkHeader, Func> longLinesSplitter, TokenCounter? tokenCounter)
{
- if (maxTokensPerParagraph <= 0)
+ ArgumentNullException.ThrowIfNull(lines);
+ ValidateMaxTokens(maxTokensPerParagraph, nameof(maxTokensPerParagraph));
+
+ if (overlapTokens < 0)
{
- throw new ArgumentException("maxTokensPerParagraph should be a positive number", nameof(maxTokensPerParagraph));
+ throw new ArgumentOutOfRangeException(nameof(overlapTokens), "overlapTokens cannot be negative.");
}
if (maxTokensPerParagraph <= overlapTokens)
{
- throw new ArgumentException("overlapTokens cannot be larger than maxTokensPerParagraph", nameof(maxTokensPerParagraph));
+ throw new ArgumentException("overlapTokens cannot be larger than or equal to maxTokensPerParagraph.", nameof(overlapTokens));
}
- // Optimize empty inputs if we can efficiently determine the're empty
+ // Optimize empty inputs if we can efficiently determine they're empty.
if (lines is ICollection c && c.Count == 0)
{
return [];
@@ -114,9 +126,13 @@ internal static class PlainTextChunker
var chunkHeaderTokens = chunkHeader is { Length: > 0 } ? GetTokenCount(chunkHeader, tokenCounter) : 0;
var adjustedMaxTokensPerParagraph = maxTokensPerParagraph - overlapTokens - chunkHeaderTokens;
+ if (adjustedMaxTokensPerParagraph <= 0)
+ {
+ throw new ArgumentException("chunkHeader and overlapTokens must leave room for paragraph content.", nameof(chunkHeader));
+ }
// Split long lines first
- var truncatedLines = lines.SelectMany(line => longLinesSplitter(line, adjustedMaxTokensPerParagraph, tokenCounter));
+ var truncatedLines = lines.SelectMany(line => longLinesSplitter(NormalizeLineEndings(line), adjustedMaxTokensPerParagraph, tokenCounter));
var paragraphs = BuildParagraph(truncatedLines, adjustedMaxTokensPerParagraph, tokenCounter);
var processedParagraphs = ProcessParagraphs(paragraphs, adjustedMaxTokensPerParagraph, overlapTokens, chunkHeader, longLinesSplitter, tokenCounter);
@@ -174,18 +190,11 @@ internal static class PlainTextChunker
if (GetTokenCount(lastParagraph, tokenCounter) < adjustedMaxTokensPerParagraph / 4)
{
- var lastParagraphTokens = lastParagraph.Split(spaceChar, StringSplitOptions.RemoveEmptyEntries);
- var secondLastParagraphTokens = secondLastParagraph.Split(spaceChar, StringSplitOptions.RemoveEmptyEntries);
+ var mergedParagraph = $"{secondLastParagraph} {lastParagraph}";
- var lastParagraphTokensCount = lastParagraphTokens.Length;
- var secondLastParagraphTokensCount = secondLastParagraphTokens.Length;
-
- if (lastParagraphTokensCount + secondLastParagraphTokensCount <= adjustedMaxTokensPerParagraph)
+ if (GetTokenCount(mergedParagraph, tokenCounter) <= adjustedMaxTokensPerParagraph)
{
- var newSecondLastParagraph = string.Join(" ", secondLastParagraphTokens);
- var newLastParagraph = string.Join(" ", lastParagraphTokens);
-
- paragraphs[^2] = $"{newSecondLastParagraph} {newLastParagraph}";
+ paragraphs[^2] = mergedParagraph;
paragraphs.RemoveAt(paragraphs.Count - 1);
}
}
@@ -232,7 +241,7 @@ internal static class PlainTextChunker
{
var result = new StringListWithTokenCount(tokenCounter);
- text = text.Replace("\r\n", "\n"); // normalize line endings
+ text = NormalizeLineEndings(text);
result.Add(text);
for (var i = 0; i < splitOptions.Length; i++)
{
@@ -314,9 +323,9 @@ internal static class PlainTextChunker
}
// Recursion
- var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(firstHalf.ToString(), tokenCounter));
+ var (splits1, split1) = Split(firstHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(firstHalf, tokenCounter));
result.AddRange(splits1);
- var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(secondHalf.ToString(), tokenCounter));
+ var (splits2, split2) = Split(secondHalf, null, maxTokens, separators, trim, tokenCounter, GetTokenCount(secondHalf, tokenCounter));
result.AddRange(splits2);
inputWasSplit = split1 || split2;
@@ -326,10 +335,14 @@ internal static class PlainTextChunker
var resultString = inputString ?? input.ToString();
var resultTokenCount = inputTokenCount;
- if (trim && !resultString.Trim().Equals(resultString, StringComparison.Ordinal))
+ if (trim)
{
- resultString = resultString.Trim();
- resultTokenCount = GetTokenCount(resultString, tokenCounter);
+ var trimmedResult = resultString.Trim();
+ if (!trimmedResult.Equals(resultString, StringComparison.Ordinal))
+ {
+ resultString = trimmedResult;
+ resultTokenCount = GetTokenCount(resultString, tokenCounter);
+ }
}
result.Add(resultString, resultTokenCount);
@@ -339,9 +352,21 @@ internal static class PlainTextChunker
private static int GetTokenCount(string input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input);
+ private static int GetTokenCount(ReadOnlySpan input, TokenCounter? tokenCounter) => tokenCounter is null ? GetDefaultTokenCount(input.Length) : tokenCounter(input.ToString());
+
+ private static string NormalizeLineEndings(string text) => text.Replace("\r\n", "\n").Replace('\r', '\n');
+
+ private static void ValidateMaxTokens(int maxTokens, string parameterName)
+ {
+ if (maxTokens <= 0)
+ {
+ throw new ArgumentOutOfRangeException(parameterName, "The maximum token count must be a positive number.");
+ }
+ }
+
private static int GetDefaultTokenCount(int length)
{
Debug.Assert(length >= 0);
- return length >> 2;
+ return length == 0 ? 0 : Math.Max(1, length >> 2);
}
}
\ No newline at end of file
diff --git a/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs b/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs
index 37a06e7..ebea231 100644
--- a/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs
+++ b/SqlDatabaseVectorSearch/TextChunkers/MarkdownTextChunker.cs
@@ -11,7 +11,7 @@ public class MarkdownTextChunker(TokenizerService tokenizerService, IOptions Split(string text)
{
- var lines = PlainTextChunker.SplitMarkDownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
+ var lines = PlainTextChunker.SplitMarkdownLines(text, appSettings.MaxTokensPerLine, tokenizerService.CountEmbeddingTokens);
var paragraphs = PlainTextChunker.SplitMarkdownParagraphs(lines, appSettings.MaxTokensPerParagraph, appSettings.OverlapTokens, tokenCounter: tokenizerService.CountEmbeddingTokens);
return paragraphs;