Skip to content

Commit e584def

Browse files
committed
GH-2123: Add chunkOverlap support to TokenTextSplitter
Fixes GH-2123 (#2123)

- Add chunkOverlap field and configuration to TokenTextSplitter class
- Implement overlap functionality in doSplit method with boundary optimization
- Add optimizeChunkBoundary method for sentence-aware chunk splitting
- Add validation to ensure chunkOverlap < chunkSize
- Update Builder pattern with withChunkOverlap method
- Add comprehensive test coverage for overlap functionality
- Improve existing tests to handle dynamic chunk counts

Signed-off-by: Seunghwan Jung <[email protected]>
1 parent e0ccc13 commit e584def

File tree

2 files changed

+235
-55
lines changed

2 files changed

+235
-55
lines changed

spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java

Lines changed: 77 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,14 @@
3333
* @author Raphael Yu
3434
* @author Christian Tzolov
3535
* @author Ricken Bazolo
36+
* @author Seunghwan Jung
3637
*/
3738
public class TokenTextSplitter extends TextSplitter {
3839

3940
private static final int DEFAULT_CHUNK_SIZE = 800;
4041

42+
private static final int DEFAULT_CHUNK_OVERLAP = 50;
43+
4144
private static final int MIN_CHUNK_SIZE_CHARS = 350;
4245

4346
private static final int MIN_CHUNK_LENGTH_TO_EMBED = 5;
@@ -46,13 +49,17 @@ public class TokenTextSplitter extends TextSplitter {
4649

4750
private static final boolean KEEP_SEPARATOR = true;
4851

52+
4953
private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();
5054

5155
private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);
5256

5357
// The target size of each text chunk in tokens
5458
private final int chunkSize;
5559

60+
// The overlap size of each text chunk in tokens
61+
private final int chunkOverlap;
62+
5663
// The minimum size of each text chunk in characters
5764
private final int minChunkSizeChars;
5865

@@ -65,16 +72,18 @@ public class TokenTextSplitter extends TextSplitter {
6572
private final boolean keepSeparator;
6673

6774
public TokenTextSplitter() {
68-
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
75+
this(DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
6976
}
7077

7178
public TokenTextSplitter(boolean keepSeparator) {
72-
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
79+
this(DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
7380
}
7481

75-
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
76-
boolean keepSeparator) {
82+
public TokenTextSplitter(int chunkSize, int chunkOverlap, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
83+
boolean keepSeparator) {
84+
Assert.isTrue(chunkOverlap < chunkSize, "chunk overlap must be less than chunk size");
7785
this.chunkSize = chunkSize;
86+
this.chunkOverlap = chunkOverlap;
7887
this.minChunkSizeChars = minChunkSizeChars;
7988
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
8089
this.maxNumChunks = maxNumChunks;
@@ -87,57 +96,80 @@ public static Builder builder() {
8796

8897
@Override
8998
protected List<String> splitText(String text) {
90-
return doSplit(text, this.chunkSize);
99+
return doSplit(text, this.chunkSize, this.chunkOverlap);
91100
}
92101

93-
protected List<String> doSplit(String text, int chunkSize) {
102+
protected List<String> doSplit(String text, int chunkSize, int chunkOverlap) {
94103
if (text == null || text.trim().isEmpty()) {
95104
return new ArrayList<>();
96105
}
97106

98107
List<Integer> tokens = getEncodedTokens(text);
99-
List<String> chunks = new ArrayList<>();
100-
int num_chunks = 0;
101-
while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) {
102-
List<Integer> chunk = tokens.subList(0, Math.min(chunkSize, tokens.size()));
103-
String chunkText = decodeTokens(chunk);
104-
105-
// Skip the chunk if it is empty or whitespace
106-
if (chunkText.trim().isEmpty()) {
107-
tokens = tokens.subList(chunk.size(), tokens.size());
108-
continue;
109-
}
108+
// If the text fits in one chunk (token count <= chunkSize), skip splitting
109+
if (tokens.size() <= chunkSize) {
110+
String processedText = this.keepSeparator ? text.trim() :
111+
text.replace(System.lineSeparator(), " ").trim();
110112

111-
// Find the last period or punctuation mark in the chunk
112-
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
113-
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
114-
115-
if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
116-
// Truncate the chunk text at the punctuation mark
117-
chunkText = chunkText.substring(0, lastPunctuation + 1);
113+
if (processedText.length() > this.minChunkLengthToEmbed) {
114+
return List.of(processedText);
118115
}
116+
return new ArrayList<>();
117+
}
118+
List<String> chunks = new ArrayList<>();
119119

120-
String chunkTextToAppend = (this.keepSeparator) ? chunkText.trim()
121-
: chunkText.replace(System.lineSeparator(), " ").trim();
122-
if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
123-
chunks.add(chunkTextToAppend);
120+
int position = 0;
121+
int num_chunks = 0;
122+
while (position < tokens.size() && num_chunks < this.maxNumChunks) {
123+
int chunkEnd = Math.min(position + chunkSize, tokens.size());
124+
125+
// Extract tokens for this chunk
126+
List<Integer> chunkTokens = tokens.subList(position, chunkEnd);
127+
String chunkText = decodeTokens(chunkTokens);
128+
129+
// Apply sentence boundary optimization
130+
String finalChunkText = optimizeChunkBoundary(chunkText);
131+
int finalChunkTokenCount = getEncodedTokens(finalChunkText).size();
132+
int advance = Math.max(1, finalChunkTokenCount - chunkOverlap);
133+
position += advance;
134+
135+
// Format according to keepSeparator setting
136+
String formattedChunk = this.keepSeparator ? finalChunkText.trim() :
137+
finalChunkText.replace(System.lineSeparator(), " ").trim();
138+
139+
// Add chunk if it meets minimum length
140+
if (formattedChunk.length() > this.minChunkLengthToEmbed) {
141+
chunks.add(formattedChunk);
142+
num_chunks++;
124143
}
144+
}
125145

126-
// Remove the tokens corresponding to the chunk text from the remaining tokens
127-
tokens = tokens.subList(getEncodedTokens(chunkText).size(), tokens.size());
146+
return chunks;
147+
}
128148

129-
num_chunks++;
149+
private String optimizeChunkBoundary(String chunkText) {
150+
if (chunkText.length() <= this.minChunkSizeChars) {
151+
return chunkText;
130152
}
131153

132-
// Handle the remaining tokens
133-
if (!tokens.isEmpty()) {
134-
String remaining_text = decodeTokens(tokens).replace(System.lineSeparator(), " ").trim();
135-
if (remaining_text.length() > this.minChunkLengthToEmbed) {
136-
chunks.add(remaining_text);
154+
// Look for sentence endings: . ! ? \n
155+
int bestCutPoint = -1;
156+
157+
// Check in reverse order to find the last sentence ending
158+
for (int i = chunkText.length() - 1; i >= this.minChunkSizeChars; i--) {
159+
char c = chunkText.charAt(i);
160+
if (c == '.' || c == '!' || c == '?' || c == '\n') {
161+
bestCutPoint = i + 1; // Include the punctuation
162+
break;
137163
}
138164
}
139165

140-
return chunks;
166+
// If we found a good cut point, use it
167+
if (bestCutPoint > 0) {
168+
return chunkText.substring(0, bestCutPoint);
169+
}
170+
171+
// Otherwise return the original chunk
172+
return chunkText;
141173
}
142174

143175
private List<Integer> getEncodedTokens(String text) {
@@ -156,6 +188,8 @@ public static final class Builder {
156188

157189
private int chunkSize = DEFAULT_CHUNK_SIZE;
158190

191+
private int chunkOverlap = DEFAULT_CHUNK_OVERLAP;
192+
159193
private int minChunkSizeChars = MIN_CHUNK_SIZE_CHARS;
160194

161195
private int minChunkLengthToEmbed = MIN_CHUNK_LENGTH_TO_EMBED;
@@ -172,6 +206,11 @@ public Builder withChunkSize(int chunkSize) {
172206
return this;
173207
}
174208

209+
public Builder withChunkOverlap(int chunkOverlap) {
210+
this.chunkOverlap = chunkOverlap;
211+
return this;
212+
}
213+
175214
public Builder withMinChunkSizeChars(int minChunkSizeChars) {
176215
this.minChunkSizeChars = minChunkSizeChars;
177216
return this;
@@ -193,7 +232,7 @@ public Builder withKeepSeparator(boolean keepSeparator) {
193232
}
194233

195234
public TokenTextSplitter build() {
196-
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
235+
return new TokenTextSplitter(this.chunkSize, this.chunkOverlap, this.minChunkSizeChars, this.minChunkLengthToEmbed,
197236
this.maxNumChunks, this.keepSeparator);
198237
}
199238

0 commit comments

Comments
 (0)