@@ -5,6 +5,7 @@ import {ChatPromptWrapper} from "./ChatPromptWrapper.js";
5
5
import { LlamaChatPromptWrapper } from "./chatWrappers/LlamaChatPromptWrapper.js" ;
6
6
import { AbortError } from "./AbortError.js" ;
7
7
8
+ const UNKNOWN_UNICODE_CHAR = "�" ;
8
9
9
10
export class LlamaChatSession {
10
11
private readonly _model : LlamaModel ;
@@ -52,7 +53,7 @@ export class LlamaChatSession {
52
53
} ) ;
53
54
}
54
55
55
- public async prompt ( prompt : string , onToken ?: ( token : number ) => void , { signal} : { signal ?: AbortSignal } = { } ) {
56
+ public async prompt ( prompt : string , onToken ?: ( tokens : number [ ] ) => void , { signal} : { signal ?: AbortSignal } = { } ) {
56
57
if ( ! this . initialized )
57
58
await this . init ( ) ;
58
59
@@ -64,56 +65,70 @@ export class LlamaChatSession {
64
65
} ) ;
65
66
}
66
67
67
- private async _evalTokens ( tokens : Uint32Array , onToken ?: ( token : number ) => void , { signal} : { signal ?: AbortSignal } = { } ) {
68
+ private async _evalTokens ( tokens : Uint32Array , onToken ?: ( tokens : number [ ] ) => void , { signal} : { signal ?: AbortSignal } = { } ) {
69
+ const decodeTokens = ( tokens : number [ ] ) => this . _model . decode ( Uint32Array . from ( tokens ) ) ;
70
+
68
71
const stopStrings = this . _promptWrapper . getStopStrings ( ) ;
69
72
const stopStringIndexes = Array ( stopStrings . length ) . fill ( 0 ) ;
70
73
const skippedChunksQueue : number [ ] = [ ] ;
71
- let res = "" ;
74
+ const res : number [ ] = [ ] ;
75
+
72
76
73
77
for await ( const chunk of this . _model . evaluate ( tokens ) ) {
74
78
if ( signal ?. aborted )
75
79
throw new AbortError ( ) ;
76
80
77
- const tokenStr = this . _model . decode ( Uint32Array . from ( [ chunk ] ) ) ;
78
- let skipTokenEvent = false ;
79
-
80
- for ( let stopStringIndex = 0 ; stopStringIndex < stopStrings . length ; stopStringIndex ++ ) {
81
- const stopString = stopStrings [ stopStringIndex ] ;
82
-
83
- let localShouldSkipTokenEvent = false ;
84
- for ( let i = 0 ; i < tokenStr . length && stopStringIndexes [ stopStringIndex ] !== stopString . length ; i ++ ) {
85
- if ( tokenStr [ i ] === stopString [ stopStringIndexes [ stopStringIndex ] ] ) {
86
- stopStringIndexes [ stopStringIndex ] ++ ;
87
- localShouldSkipTokenEvent = true ;
88
- } else {
89
- stopStringIndexes [ stopStringIndex ] = 0 ;
90
- localShouldSkipTokenEvent = false ;
91
- break ;
92
- }
93
- }
81
+ const tokenStr = decodeTokens ( [ chunk ] ) ;
82
+ const { shouldReturn, skipTokenEvent} = this . _checkStopString ( tokenStr , stopStringIndexes ) ;
83
+
84
+ if ( shouldReturn )
85
+ return decodeTokens ( res ) ;
94
86
95
- if ( stopStringIndexes [ stopStringIndex ] === stopString . length ) {
96
- return res ;
97
- }
87
+ // if the token is unknown, it means it's not complete character
88
+ if ( tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent ) {
89
+ skippedChunksQueue . push ( chunk ) ;
90
+ continue ;
91
+ }
98
92
99
- skipTokenEvent ||= localShouldSkipTokenEvent ;
93
+ if ( skippedChunksQueue . length > 0 ) {
94
+ res . push ( ...skippedChunksQueue ) ;
95
+ onToken ?.( skippedChunksQueue ) ;
96
+ skippedChunksQueue . length = 0 ;
100
97
}
101
98
102
- if ( skipTokenEvent ) {
103
- skippedChunksQueue . push ( chunk ) ;
104
- continue ;
99
+ res . push ( chunk ) ;
100
+ onToken ?.( [ chunk ] ) ;
101
+ }
102
+
103
+ return decodeTokens ( res ) ;
104
+ }
105
+
106
+ private _checkStopString ( tokenStr : string , stopStringIndexes : number [ ] ) {
107
+ const stopStrings = this . _promptWrapper . getStopStrings ( ) ;
108
+ let skipTokenEvent = false ;
109
+
110
+ for ( let stopStringIndex = 0 ; stopStringIndex < stopStrings . length ; stopStringIndex ++ ) {
111
+ const stopString = stopStrings [ stopStringIndex ] ;
112
+
113
+ let localShouldSkipTokenEvent = false ;
114
+ for ( let i = 0 ; i < tokenStr . length && stopStringIndexes [ stopStringIndex ] !== stopString . length ; i ++ ) {
115
+ if ( tokenStr [ i ] === stopString [ stopStringIndexes [ stopStringIndex ] ] ) {
116
+ stopStringIndexes [ stopStringIndex ] ++ ;
117
+ localShouldSkipTokenEvent = true ;
118
+ } else {
119
+ stopStringIndexes [ stopStringIndex ] = 0 ;
120
+ localShouldSkipTokenEvent = false ;
121
+ break ;
122
+ }
105
123
}
106
124
107
- while ( skippedChunksQueue . length > 0 ) {
108
- const token = skippedChunksQueue . shift ( ) ! ;
109
- res += this . _model . decode ( Uint32Array . from ( [ token ] ) ) ;
110
- onToken ?.( token ) ;
125
+ if ( stopStringIndexes [ stopStringIndex ] === stopString . length ) {
126
+ return { shouldReturn : true } ;
111
127
}
112
128
113
- res += tokenStr ;
114
- onToken ?.( chunk ) ;
129
+ skipTokenEvent ||= localShouldSkipTokenEvent ;
115
130
}
116
131
117
- return res ;
132
+ return { skipTokenEvent } ;
118
133
}
119
134
}
0 commit comments