Skip to content

Commit 9773099

Browse files
oneby-wangilayaperumalg
authored andcommitted
feat: Support custom punctuation marks in TokenTextSplitter
Signed-off-by: oneby-wang <[email protected]>
1 parent 9e857ec commit 9773099

File tree

2 files changed

+70
-6
lines changed

2 files changed

+70
-6
lines changed

spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ public class TokenTextSplitter extends TextSplitter {
4646

4747
private static final boolean KEEP_SEPARATOR = true;
4848

49+
private static final List<Character> DEFAULT_PUNCTUATION_MARKS = List.of('.', '?', '!', '\n');
50+
4951
private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();
5052

5153
private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);
@@ -64,21 +66,27 @@ public class TokenTextSplitter extends TextSplitter {
6466

6567
private final boolean keepSeparator;
6668

69+
private final List<Character> punctuationMarks;
70+
6771
public TokenTextSplitter() {
68-
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
72+
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR,
73+
DEFAULT_PUNCTUATION_MARKS);
6974
}
7075

7176
public TokenTextSplitter(boolean keepSeparator) {
72-
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
77+
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator,
78+
DEFAULT_PUNCTUATION_MARKS);
7379
}
7480

7581
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
76-
boolean keepSeparator) {
82+
boolean keepSeparator, List<Character> punctuationMarks) {
7783
this.chunkSize = chunkSize;
7884
this.minChunkSizeChars = minChunkSizeChars;
7985
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
8086
this.maxNumChunks = maxNumChunks;
8187
this.keepSeparator = keepSeparator;
88+
Assert.notEmpty(punctuationMarks, "punctuationMarks must not be empty");
89+
this.punctuationMarks = punctuationMarks;
8290
}
8391

8492
public static Builder builder() {
@@ -124,8 +132,7 @@ protected List<String> doSplit(String text, int chunkSize) {
124132
// This prevents unnecessary splitting of small texts
125133
if (tokens.size() > chunkSize) {
126134
// Find the last period or punctuation mark in the chunk
127-
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
128-
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
135+
int lastPunctuation = getLastPunctuationIndex(chunkText);
129136

130137
if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
131138
// Truncate the chunk text at the punctuation mark
@@ -156,6 +163,16 @@ protected List<String> doSplit(String text, int chunkSize) {
156163
return chunks;
157164
}
158165

166+
protected int getLastPunctuationIndex(String chunkText) {
167+
// find the max index of any punctuation mark
168+
int maxLastPunctuation = -1;
169+
for (Character punctuationMark : this.punctuationMarks) {
170+
int lastPunctuation = chunkText.lastIndexOf(punctuationMark);
171+
maxLastPunctuation = Math.max(maxLastPunctuation, lastPunctuation);
172+
}
173+
return maxLastPunctuation;
174+
}
175+
159176
private List<Integer> getEncodedTokens(String text) {
160177
Assert.notNull(text, "Text must not be null");
161178
return this.encoding.encode(text).boxed();
@@ -180,6 +197,8 @@ public static final class Builder {
180197

181198
private boolean keepSeparator = KEEP_SEPARATOR;
182199

200+
private List<Character> punctuationMarks = DEFAULT_PUNCTUATION_MARKS;
201+
183202
private Builder() {
184203
}
185204

@@ -208,9 +227,14 @@ public Builder withKeepSeparator(boolean keepSeparator) {
208227
return this;
209228
}
210229

230+
public Builder withPunctuationMarks(List<Character> punctuationMarks) {
231+
this.punctuationMarks = punctuationMarks;
232+
return this;
233+
}
234+
211235
public TokenTextSplitter build() {
212236
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
213-
this.maxNumChunks, this.keepSeparator);
237+
this.maxNumChunks, this.keepSeparator, this.punctuationMarks);
214238
}
215239

216240
}

spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,44 @@ public void testLargeTextStillSplitsAtPunctuation() {
165165
assertThat(splitted.get(0).getText()).endsWith(".");
166166
}
167167

168+
@Test
169+
public void testTokenTextSplitterWithCustomPunctuationMarks() {
170+
var contentFormatter1 = DefaultContentFormatter.defaultConfig();
171+
var contentFormatter2 = DefaultContentFormatter.defaultConfig();
172+
173+
assertThat(contentFormatter1).isNotSameAs(contentFormatter2);
174+
175+
var doc1 = new Document("Here, we set custom punctuation marks。?!. We just want to test it works or not?");
176+
doc1.setContentFormatter(contentFormatter1);
177+
178+
var doc2 = new Document("And more, we add protected method getLastPunctuationIndex in TokenTextSplitter class!"
179+
+ "The subclasses can override this method to achieve their own business logic。We just want to test it works or not?");
180+
doc2.setContentFormatter(contentFormatter2);
181+
182+
var tokenTextSplitter = TokenTextSplitter.builder()
183+
.withChunkSize(10)
184+
.withMinChunkSizeChars(5)
185+
.withMinChunkLengthToEmbed(3)
186+
.withMaxNumChunks(50)
187+
.withKeepSeparator(true)
188+
.withPunctuationMarks(List.of('。', '?', '!'))
189+
.build();
190+
191+
var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));
192+
193+
assertThat(chunks.size()).isEqualTo(7);
194+
195+
// Doc 1
196+
assertThat(chunks.get(0).getText()).isEqualTo("Here, we set custom punctuation marks。?!");
197+
assertThat(chunks.get(1).getText()).isEqualTo(". We just want to test it works or not");
198+
199+
// Doc 2
200+
assertThat(chunks.get(2).getText()).isEqualTo("And more, we add protected method getLastPunctuation");
201+
assertThat(chunks.get(3).getText()).isEqualTo("Index in TokenTextSplitter class!");
202+
assertThat(chunks.get(4).getText()).isEqualTo("The subclasses can override this method to achieve their own");
203+
assertThat(chunks.get(5).getText()).isEqualTo("business logic。");
204+
assertThat(chunks.get(6).getText()).isEqualTo("We just want to test it works or not?");
205+
206+
}
207+
168208
}

0 commit comments

Comments
 (0)