Skip to content

Commit 9bdef11

Browse files
authored
fix: handle stop words remainder properly in a chat session (#32)
1 parent dd49959 commit 9bdef11

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

src/llamaEvaluator/LlamaChatSession.ts

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ export class LlamaChatSession {
105105
onToken, signal, maxTokens
106106
}: { onToken?(tokens: Token[]): void, signal?: AbortSignal, maxTokens?: number } = {}) {
107107
const stopStrings = this._promptWrapper.getStopStrings();
108-
const stopStringIndexes = Array(stopStrings.length).fill(0);
108+
const stopStringIndexes: number[] = Array(stopStrings.length).fill(0);
109109
const skippedChunksQueue: Token[] = [];
110110
const res: Token[] = [];
111111

@@ -114,14 +114,32 @@ export class LlamaChatSession {
114114
throw new AbortError();
115115

116116
const tokenStr = this._ctx.decode(Uint32Array.from([chunk]));
117-
const {shouldReturn, skipTokenEvent, stopString, stopStringSuffix} = this._checkStopString(tokenStr, stopStringIndexes);
117+
const {
118+
shouldReturn, skipTokenEvent, stopString, stopStringSuffix
119+
} = this._checkStopString(tokenStr, stopStrings, stopStringIndexes);
120+
121+
if (shouldReturn) {
122+
skippedChunksQueue.push(chunk);
123+
const skippedChunksText = skippedChunksQueue.length > 0
124+
? this._ctx.decode(Uint32Array.from(skippedChunksQueue))
125+
: "";
126+
127+
const [queuedTextBeforeStopString] = skippedChunksText.split(stopString);
128+
129+
if (queuedTextBeforeStopString.length > 0) {
130+
const beforeStopStringTokens: Token[] = Array.from(this._ctx.encode(queuedTextBeforeStopString));
131+
132+
res.push(...beforeStopStringTokens);
133+
onToken?.(beforeStopStringTokens);
134+
skippedChunksQueue.length = 0;
135+
}
118136

119-
if (shouldReturn)
120137
return {
121138
text: this._ctx.decode(Uint32Array.from(res)),
122139
stopString,
123140
stopStringSuffix
124141
};
142+
}
125143

126144
// if the token is unknown, it means it's not complete character
127145
if (tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent) {
@@ -149,32 +167,31 @@ export class LlamaChatSession {
149167
};
150168
}
151169

152-
private _checkStopString(tokenStr: string, stopStringIndexes: number[]){
153-
const stopStrings = this._promptWrapper.getStopStrings();
170+
private _checkStopString(tokenStr: string, stopStrings: string[], stopStringIndexes: number[]){
154171
let skipTokenEvent = false;
155172

156173
for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) {
157174
const stopString = stopStrings[stopStringIndex];
158175

159176
let localShouldSkipTokenEvent = false;
160-
for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
177+
let i = 0;
178+
for (; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
161179
if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) {
162180
stopStringIndexes[stopStringIndex]++;
163181
localShouldSkipTokenEvent = true;
164182
} else {
165183
stopStringIndexes[stopStringIndex] = 0;
166184
localShouldSkipTokenEvent = false;
167-
break;
168185
}
169186
}
170187

171188
if (stopStringIndexes[stopStringIndex] === stopString.length) {
172189
return {
173190
shouldReturn: true,
174191
stopString,
175-
stopStringSuffix: tokenStr.length === stopString.length
192+
stopStringSuffix: tokenStr.length === i
176193
? null
177-
: tokenStr.slice(stopString.length)
194+
: tokenStr.slice(i)
178195
};
179196
}
180197

src/llamaEvaluator/LlamaContext.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@ export class LlamaContext {
2929
return this._ctx.encode(text);
3030
}
3131

32-
public decode(tokens: Uint32Array): string {
32+
public decode(tokens: Uint32Array | Token[]): string {
3333
if (tokens.length === 0)
3434
return "";
3535

36-
return this._ctx.decode(tokens);
36+
if (tokens instanceof Uint32Array)
37+
return this._ctx.decode(tokens);
38+
39+
return this._ctx.decode(Uint32Array.from(tokens));
3740
}
3841

3942
public get prependBos() {

0 commit comments

Comments
 (0)