@@ -105,7 +105,7 @@ export class LlamaChatSession {
105
105
onToken, signal, maxTokens
106
106
} : { onToken ?( tokens : Token [ ] ) : void , signal ?: AbortSignal , maxTokens ?: number } = { } ) {
107
107
const stopStrings = this . _promptWrapper . getStopStrings ( ) ;
108
- const stopStringIndexes = Array ( stopStrings . length ) . fill ( 0 ) ;
108
+ const stopStringIndexes : number [ ] = Array ( stopStrings . length ) . fill ( 0 ) ;
109
109
const skippedChunksQueue : Token [ ] = [ ] ;
110
110
const res : Token [ ] = [ ] ;
111
111
@@ -114,14 +114,32 @@ export class LlamaChatSession {
114
114
throw new AbortError ( ) ;
115
115
116
116
const tokenStr = this . _ctx . decode ( Uint32Array . from ( [ chunk ] ) ) ;
117
- const { shouldReturn, skipTokenEvent, stopString, stopStringSuffix} = this . _checkStopString ( tokenStr , stopStringIndexes ) ;
117
+ const {
118
+ shouldReturn, skipTokenEvent, stopString, stopStringSuffix
119
+ } = this . _checkStopString ( tokenStr , stopStrings , stopStringIndexes ) ;
120
+
121
+ if ( shouldReturn ) {
122
+ skippedChunksQueue . push ( chunk ) ;
123
+ const skippedChunksText = skippedChunksQueue . length > 0
124
+ ? this . _ctx . decode ( Uint32Array . from ( skippedChunksQueue ) )
125
+ : "" ;
126
+
127
+ const [ queuedTextBeforeStopString ] = skippedChunksText . split ( stopString ) ;
128
+
129
+ if ( queuedTextBeforeStopString . length > 0 ) {
130
+ const beforeStopStringTokens : Token [ ] = Array . from ( this . _ctx . encode ( queuedTextBeforeStopString ) ) ;
131
+
132
+ res . push ( ...beforeStopStringTokens ) ;
133
+ onToken ?.( beforeStopStringTokens ) ;
134
+ skippedChunksQueue . length = 0 ;
135
+ }
118
136
119
- if ( shouldReturn )
120
137
return {
121
138
text : this . _ctx . decode ( Uint32Array . from ( res ) ) ,
122
139
stopString,
123
140
stopStringSuffix
124
141
} ;
142
+ }
125
143
126
144
// if the token is unknown, it means it's not complete character
127
145
if ( tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent ) {
@@ -149,32 +167,31 @@ export class LlamaChatSession {
149
167
} ;
150
168
}
151
169
152
- private _checkStopString ( tokenStr : string , stopStringIndexes : number [ ] ) {
153
- const stopStrings = this . _promptWrapper . getStopStrings ( ) ;
170
+ private _checkStopString ( tokenStr : string , stopStrings : string [ ] , stopStringIndexes : number [ ] ) {
154
171
let skipTokenEvent = false ;
155
172
156
173
for ( let stopStringIndex = 0 ; stopStringIndex < stopStrings . length ; stopStringIndex ++ ) {
157
174
const stopString = stopStrings [ stopStringIndex ] ;
158
175
159
176
let localShouldSkipTokenEvent = false ;
160
- for ( let i = 0 ; i < tokenStr . length && stopStringIndexes [ stopStringIndex ] !== stopString . length ; i ++ ) {
177
+ let i = 0 ;
178
+ for ( ; i < tokenStr . length && stopStringIndexes [ stopStringIndex ] !== stopString . length ; i ++ ) {
161
179
if ( tokenStr [ i ] === stopString [ stopStringIndexes [ stopStringIndex ] ] ) {
162
180
stopStringIndexes [ stopStringIndex ] ++ ;
163
181
localShouldSkipTokenEvent = true ;
164
182
} else {
165
183
stopStringIndexes [ stopStringIndex ] = 0 ;
166
184
localShouldSkipTokenEvent = false ;
167
- break ;
168
185
}
169
186
}
170
187
171
188
if ( stopStringIndexes [ stopStringIndex ] === stopString . length ) {
172
189
return {
173
190
shouldReturn : true ,
174
191
stopString,
175
- stopStringSuffix : tokenStr . length === stopString . length
192
+ stopStringSuffix : tokenStr . length === i
176
193
? null
177
- : tokenStr . slice ( stopString . length )
194
+ : tokenStr . slice ( i )
178
195
} ;
179
196
}
180
197
0 commit comments