33
33
* @author Raphael Yu
34
34
* @author Christian Tzolov
35
35
* @author Ricken Bazolo
36
+ * @author Seunghwan Jung
36
37
*/
37
38
public class TokenTextSplitter extends TextSplitter {
38
39
39
40
private static final int DEFAULT_CHUNK_SIZE = 800 ;
40
41
42
+ private static final int DEFAULT_CHUNK_OVERLAP = 50 ;
43
+
41
44
private static final int MIN_CHUNK_SIZE_CHARS = 350 ;
42
45
43
46
private static final int MIN_CHUNK_LENGTH_TO_EMBED = 5 ;
@@ -46,13 +49,17 @@ public class TokenTextSplitter extends TextSplitter {
46
49
47
50
private static final boolean KEEP_SEPARATOR = true ;
48
51
52
+
49
53
private final EncodingRegistry registry = Encodings .newLazyEncodingRegistry ();
50
54
51
55
private final Encoding encoding = this .registry .getEncoding (EncodingType .CL100K_BASE );
52
56
53
57
// The target size of each text chunk in tokens
54
58
private final int chunkSize ;
55
59
60
+ // The overlap size of each text chunk in tokens
61
+ private final int chunkOverlap ;
62
+
56
63
// The minimum size of each text chunk in characters
57
64
private final int minChunkSizeChars ;
58
65
@@ -65,16 +72,18 @@ public class TokenTextSplitter extends TextSplitter {
65
72
private final boolean keepSeparator ;
66
73
67
74
/**
 * Creates a splitter in which every setting, including the chunk overlap, takes its
 * class-level default.
 */
public TokenTextSplitter() {
    this(DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED,
            MAX_NUM_CHUNKS, KEEP_SEPARATOR);
}
70
77
71
78
/**
 * Creates a splitter that uses the class-level defaults for everything except the
 * separator handling.
 * @param keepSeparator when {@code false}, line separators inside produced chunks are
 * replaced by single spaces; when {@code true}, chunks are only trimmed
 */
public TokenTextSplitter(boolean keepSeparator) {
    this(DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED,
            MAX_NUM_CHUNKS, keepSeparator);
}
74
81
75
- public TokenTextSplitter (int chunkSize , int minChunkSizeChars , int minChunkLengthToEmbed , int maxNumChunks ,
76
- boolean keepSeparator ) {
82
+ public TokenTextSplitter (int chunkSize , int chunkOverlap , int minChunkSizeChars , int minChunkLengthToEmbed , int maxNumChunks ,
83
+ boolean keepSeparator ) {
84
+ Assert .isTrue (chunkOverlap < chunkSize , "chunk overlap must be less than chunk size" );
77
85
this .chunkSize = chunkSize ;
86
+ this .chunkOverlap = chunkOverlap ;
78
87
this .minChunkSizeChars = minChunkSizeChars ;
79
88
this .minChunkLengthToEmbed = minChunkLengthToEmbed ;
80
89
this .maxNumChunks = maxNumChunks ;
@@ -87,57 +96,80 @@ public static Builder builder() {
87
96
88
97
@ Override
89
98
protected List <String > splitText (String text ) {
90
- return doSplit (text , this .chunkSize );
99
+ return doSplit (text , this .chunkSize , this . chunkOverlap );
91
100
}
92
101
93
- protected List <String > doSplit (String text , int chunkSize ) {
102
+ protected List <String > doSplit (String text , int chunkSize , int chunkOverlap ) {
94
103
if (text == null || text .trim ().isEmpty ()) {
95
104
return new ArrayList <>();
96
105
}
97
106
98
107
List <Integer > tokens = getEncodedTokens (text );
99
- List <String > chunks = new ArrayList <>();
100
- int num_chunks = 0 ;
101
- while (!tokens .isEmpty () && num_chunks < this .maxNumChunks ) {
102
- List <Integer > chunk = tokens .subList (0 , Math .min (chunkSize , tokens .size ()));
103
- String chunkText = decodeTokens (chunk );
104
-
105
- // Skip the chunk if it is empty or whitespace
106
- if (chunkText .trim ().isEmpty ()) {
107
- tokens = tokens .subList (chunk .size (), tokens .size ());
108
- continue ;
109
- }
108
+ // If text is smaller than chunk size, return as a single chunk
109
+ if (tokens .size () <= chunkSize ) {
110
+ String processedText = this .keepSeparator ? text .trim () :
111
+ text .replace (System .lineSeparator (), " " ).trim ();
110
112
111
- // Find the last period or punctuation mark in the chunk
112
- int lastPunctuation = Math .max (chunkText .lastIndexOf ('.' ), Math .max (chunkText .lastIndexOf ('?' ),
113
- Math .max (chunkText .lastIndexOf ('!' ), chunkText .lastIndexOf ('\n' ))));
114
-
115
- if (lastPunctuation != -1 && lastPunctuation > this .minChunkSizeChars ) {
116
- // Truncate the chunk text at the punctuation mark
117
- chunkText = chunkText .substring (0 , lastPunctuation + 1 );
113
+ if (processedText .length () > this .minChunkLengthToEmbed ) {
114
+ return List .of (processedText );
118
115
}
116
+ return new ArrayList <>();
117
+ }
118
+ List <String > chunks = new ArrayList <>();
119
119
120
- String chunkTextToAppend = (this .keepSeparator ) ? chunkText .trim ()
121
- : chunkText .replace (System .lineSeparator (), " " ).trim ();
122
- if (chunkTextToAppend .length () > this .minChunkLengthToEmbed ) {
123
- chunks .add (chunkTextToAppend );
120
+ int position = 0 ;
121
+ int num_chunks = 0 ;
122
+ while (position < tokens .size () && num_chunks < this .maxNumChunks ) {
123
+ int chunkEnd = Math .min (position + chunkSize , tokens .size ());
124
+
125
+ // Extract tokens for this chunk
126
+ List <Integer > chunkTokens = tokens .subList (position , chunkEnd );
127
+ String chunkText = decodeTokens (chunkTokens );
128
+
129
+ // Apply sentence boundary optimization
130
+ String finalChunkText = optimizeChunkBoundary (chunkText );
131
+ int finalChunkTokenCount = getEncodedTokens (finalChunkText ).size ();
132
+ int advance = Math .max (1 , finalChunkTokenCount - chunkOverlap );
133
+ position += advance ;
134
+
135
+ // Format according to keepSeparator setting
136
+ String formattedChunk = this .keepSeparator ? finalChunkText .trim () :
137
+ finalChunkText .replace (System .lineSeparator (), " " ).trim ();
138
+
139
+ // Add chunk if it meets minimum length
140
+ if (formattedChunk .length () > this .minChunkLengthToEmbed ) {
141
+ chunks .add (formattedChunk );
142
+ num_chunks ++;
124
143
}
144
+ }
125
145
126
- // Remove the tokens corresponding to the chunk text from the remaining tokens
127
- tokens = tokens . subList ( getEncodedTokens ( chunkText ). size (), tokens . size ());
146
+ return chunks ;
147
+ }
128
148
129
- num_chunks ++;
149
+ private String optimizeChunkBoundary (String chunkText ) {
150
+ if (chunkText .length () <= this .minChunkSizeChars ) {
151
+ return chunkText ;
130
152
}
131
153
132
- // Handle the remaining tokens
133
- if (!tokens .isEmpty ()) {
134
- String remaining_text = decodeTokens (tokens ).replace (System .lineSeparator (), " " ).trim ();
135
- if (remaining_text .length () > this .minChunkLengthToEmbed ) {
136
- chunks .add (remaining_text );
154
+ // Look for sentence endings: . ! ? \n
155
+ int bestCutPoint = -1 ;
156
+
157
+ // Check in reverse order to find the last sentence ending
158
+ for (int i = chunkText .length () - 1 ; i >= this .minChunkSizeChars ; i --) {
159
+ char c = chunkText .charAt (i );
160
+ if (c == '.' || c == '!' || c == '?' || c == '\n' ) {
161
+ bestCutPoint = i + 1 ; // Include the punctuation
162
+ break ;
137
163
}
138
164
}
139
165
140
- return chunks ;
166
+ // If we found a good cut point, use it
167
+ if (bestCutPoint > 0 ) {
168
+ return chunkText .substring (0 , bestCutPoint );
169
+ }
170
+
171
+ // Otherwise return the original chunk
172
+ return chunkText ;
141
173
}
142
174
143
175
private List <Integer > getEncodedTokens (String text ) {
@@ -156,6 +188,8 @@ public static final class Builder {
156
188
157
189
private int chunkSize = DEFAULT_CHUNK_SIZE ;
158
190
191
+ private int chunkOverlap = DEFAULT_CHUNK_OVERLAP ;
192
+
159
193
private int minChunkSizeChars = MIN_CHUNK_SIZE_CHARS ;
160
194
161
195
private int minChunkLengthToEmbed = MIN_CHUNK_LENGTH_TO_EMBED ;
@@ -172,6 +206,11 @@ public Builder withChunkSize(int chunkSize) {
172
206
return this ;
173
207
}
174
208
209
+ public Builder withChunkOverlap (int chunkOverlap ) {
210
+ this .chunkOverlap = chunkOverlap ;
211
+ return this ;
212
+ }
213
+
175
214
public Builder withMinChunkSizeChars (int minChunkSizeChars ) {
176
215
this .minChunkSizeChars = minChunkSizeChars ;
177
216
return this ;
@@ -193,7 +232,7 @@ public Builder withKeepSeparator(boolean keepSeparator) {
193
232
}
194
233
195
234
public TokenTextSplitter build () {
196
- return new TokenTextSplitter (this .chunkSize , this .minChunkSizeChars , this .minChunkLengthToEmbed ,
235
+ return new TokenTextSplitter (this .chunkSize , this .chunkOverlap , this . minChunkSizeChars , this .minChunkLengthToEmbed ,
197
236
this .maxNumChunks , this .keepSeparator );
198
237
}
199
238
0 commit comments