@@ -105,7 +105,7 @@ export class LlamaChatSession {
105105 onToken, signal, maxTokens
106106 } : { onToken ?( tokens : Token [ ] ) : void , signal ?: AbortSignal , maxTokens ?: number } = { } ) {
107107 const stopStrings = this . _promptWrapper . getStopStrings ( ) ;
108- const stopStringIndexes = Array ( stopStrings . length ) . fill ( 0 ) ;
108+ const stopStringIndexes : number [ ] = Array ( stopStrings . length ) . fill ( 0 ) ;
109109 const skippedChunksQueue : Token [ ] = [ ] ;
110110 const res : Token [ ] = [ ] ;
111111
@@ -114,14 +114,32 @@ export class LlamaChatSession {
114114 throw new AbortError ( ) ;
115115
116116 const tokenStr = this . _ctx . decode ( Uint32Array . from ( [ chunk ] ) ) ;
117- const { shouldReturn, skipTokenEvent, stopString, stopStringSuffix} = this . _checkStopString ( tokenStr , stopStringIndexes ) ;
117+ const {
118+ shouldReturn, skipTokenEvent, stopString, stopStringSuffix
119+ } = this . _checkStopString ( tokenStr , stopStrings , stopStringIndexes ) ;
120+
121+ if ( shouldReturn ) {
122+ skippedChunksQueue . push ( chunk ) ;
123+ const skippedChunksText = skippedChunksQueue . length > 0
124+ ? this . _ctx . decode ( Uint32Array . from ( skippedChunksQueue ) )
125+ : "" ;
126+
127+ const [ queuedTextBeforeStopString ] = skippedChunksText . split ( stopString ) ;
128+
129+ if ( queuedTextBeforeStopString . length > 0 ) {
130+ const beforeStopStringTokens : Token [ ] = Array . from ( this . _ctx . encode ( queuedTextBeforeStopString ) ) ;
131+
132+ res . push ( ...beforeStopStringTokens ) ;
133+ onToken ?.( beforeStopStringTokens ) ;
134+ skippedChunksQueue . length = 0 ;
135+ }
118136
119- if ( shouldReturn )
120137 return {
121138 text : this . _ctx . decode ( Uint32Array . from ( res ) ) ,
122139 stopString,
123140 stopStringSuffix
124141 } ;
142+ }
125143
126144 // if the token is unknown, it means it's not a complete character
127145 if ( tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent ) {
@@ -149,32 +167,31 @@ export class LlamaChatSession {
149167 } ;
150168 }
151169
152- private _checkStopString ( tokenStr : string , stopStringIndexes : number [ ] ) {
153- const stopStrings = this . _promptWrapper . getStopStrings ( ) ;
170+ private _checkStopString ( tokenStr : string , stopStrings : string [ ] , stopStringIndexes : number [ ] ) {
154171 let skipTokenEvent = false ;
155172
156173 for ( let stopStringIndex = 0 ; stopStringIndex < stopStrings . length ; stopStringIndex ++ ) {
157174 const stopString = stopStrings [ stopStringIndex ] ;
158175
159176 let localShouldSkipTokenEvent = false ;
160- for ( let i = 0 ; i < tokenStr . length && stopStringIndexes [ stopStringIndex ] !== stopString . length ; i ++ ) {
177+ let i = 0 ;
178+ for ( ; i < tokenStr . length && stopStringIndexes [ stopStringIndex ] !== stopString . length ; i ++ ) {
161179 if ( tokenStr [ i ] === stopString [ stopStringIndexes [ stopStringIndex ] ] ) {
162180 stopStringIndexes [ stopStringIndex ] ++ ;
163181 localShouldSkipTokenEvent = true ;
164182 } else {
165183 stopStringIndexes [ stopStringIndex ] = 0 ;
166184 localShouldSkipTokenEvent = false ;
167- break ;
168185 }
169186 }
170187
171188 if ( stopStringIndexes [ stopStringIndex ] === stopString . length ) {
172189 return {
173190 shouldReturn : true ,
174191 stopString,
175- stopStringSuffix : tokenStr . length === stopString . length
192+ stopStringSuffix : tokenStr . length === i
176193 ? null
177- : tokenStr . slice ( stopString . length )
194+ : tokenStr . slice ( i )
178195 } ;
179196 }
180197
0 commit comments