
Commit 9cab784

fix: long LlamaText tokenization (#249)
1 parent c89178f commit 9cab784

File tree: 9 files changed, +53 −32 lines changed


src/evaluator/LlamaChat/LlamaChat.ts

Lines changed: 15 additions & 14 deletions
@@ -18,6 +18,7 @@ import {resolveChatWrapper} from "../../chatWrappers/utils/resolveChatWrapper.js
 import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js";
 import {TokenBias} from "../TokenBias.js";
 import {safeEventCallback} from "../../utils/safeEventCallback.js";
+import {pushAll} from "../../utils/pushAll.js";
 import {
     eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy
 } from "./utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js";
@@ -1491,7 +1492,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
         })
         .flat(1);
     this.pendingTokens.length = 0;
-    this.pendingTokens.push(...newPendingTokens);
+    pushAll(this.pendingTokens, newPendingTokens);
     this.removedStartTextToIgnore = true;
 }
 }
@@ -1975,7 +1976,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin

 this.stopGenerationDetector.clearInProgressStops();
 this.customStopGenerationTriggersDetector.clearInProgressStops();
-this.pendingTokens.push(...this.streamRegulator.popFreeChunkTokens());
+pushAll(this.pendingTokens, this.streamRegulator.popFreeChunkTokens());

 const triggeredStops = this.functionSyntaxStartDetector.getTriggeredStops();
 const partiallyFreeTokens = this.streamRegulator.getPartiallyFreeChunk(this.llamaChat.model.tokenizer);
@@ -1984,15 +1985,15 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
     partiallyFreeTokens,
     this.llamaChat.model.tokenizer
 );
-this.pendingTokens.push(...queuedTokensBeforeStopTrigger);
+pushAll(this.pendingTokens, queuedTokensBeforeStopTrigger);

 this.removeFoundStartIgnoreTextsFromPendingTokens(true);

 if (this.pendingTokens.length > 0)
     this.onToken?.(this.pendingTokens.slice());

-this.res.push(...this.pendingTokens);
-this.contextWindowsRes.push(...this.pendingTokens);
+pushAll(this.res, this.pendingTokens);
+pushAll(this.contextWindowsRes, this.pendingTokens);
 this.pendingTokens.length = 0;

 this.streamRegulator.clearQueue();
@@ -2192,7 +2193,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
 this.customStopGenerationTriggersDetector.clearTriggeredStops();
 this.customStopGenerationTriggersDetector.clearInProgressStops();

-this.pendingTokens.push(...this.streamRegulator.popFreeChunkTokens());
+pushAll(this.pendingTokens, this.streamRegulator.popFreeChunkTokens());

 const triggeredStops = this.functionSyntaxStartDetector.getTriggeredStops();
 const partiallyFreeTokens = this.streamRegulator.getPartiallyFreeChunk(this.llamaChat.model.tokenizer);
@@ -2202,7 +2203,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
     partiallyFreeTokens,
     this.llamaChat.model.tokenizer
 );
-this.pendingTokens.push(...queuedTokensBeforeStopTrigger);
+pushAll(this.pendingTokens, queuedTokensBeforeStopTrigger);

 const firstRemainingGenerationAfterStop = StopGenerationDetector.getFirstRemainingGenerationAfterStop(triggeredStops);
 const remainingTextAfterStop = StopGenerationDetector.detokenizeRemainingGeneration(
@@ -2228,7 +2229,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
 }

 public popStreamRegulatorFreeTokens() {
-    this.pendingTokens.push(...this.streamRegulator.popFreeChunkTokens());
+    pushAll(this.pendingTokens, this.streamRegulator.popFreeChunkTokens());
 }

 public handleStopGenerationTrigger(lastHistoryItemType: "user" | "model") {
@@ -2237,7 +2238,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
 ) {
     this.stopGenerationDetector.clearInProgressStops();
     this.customStopGenerationTriggersDetector.clearInProgressStops();
-    this.pendingTokens.push(...this.streamRegulator.popFreeChunkTokens());
+    pushAll(this.pendingTokens, this.streamRegulator.popFreeChunkTokens());

     const triggeredStops = this.stopGenerationDetector.hasTriggeredStops
         ? this.stopGenerationDetector.getTriggeredStops()
@@ -2250,7 +2251,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
     partiallyFreeTokens,
     this.llamaChat.model.tokenizer
 );
-this.pendingTokens.push(...queuedTokensBeforeStopTrigger);
+pushAll(this.pendingTokens, queuedTokensBeforeStopTrigger);

 const firstRemainingGenerationAfterStop = StopGenerationDetector.getFirstRemainingGenerationAfterStop(triggeredStops);

@@ -2259,8 +2260,8 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
 if (this.pendingTokens.length > 0)
     this.onToken?.(this.pendingTokens.slice());

-this.res.push(...this.pendingTokens);
-this.contextWindowsRes.push(...this.pendingTokens);
+pushAll(this.res, this.pendingTokens);
+pushAll(this.contextWindowsRes, this.pendingTokens);
 this.pendingTokens.length = 0;

 let modelResponse = this.llamaChat.model.detokenize(this.res);
@@ -2336,8 +2337,8 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin

 if (this.pendingTokens.length > 0) {
     this.onToken?.(this.pendingTokens.slice());
-    this.res.push(...this.pendingTokens);
-    this.contextWindowsRes.push(...this.pendingTokens);
+    pushAll(this.res, this.pendingTokens);
+    pushAll(this.contextWindowsRes, this.pendingTokens);
     this.pendingTokens.length = 0;
 }
 }
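
Every push(...tokens) call in this file is replaced with the new pushAll helper. Spreading an array into push() passes each element as a separate call argument, so a long enough token array can exceed the engine's argument/stack limit and throw a RangeError instead of appending. A minimal sketch of the failure mode and the workaround (the array length is illustrative; the exact limit depends on the JavaScript engine):

    // Illustrative only: the length at which spreading fails varies by engine.
    const pending: number[] = [];
    const manyTokens = new Array<number>(1_000_000).fill(0);

    // pending.push(...manyTokens); // can throw RangeError: Maximum call stack size exceeded

    // Appending one element at a time (what pushAll does) has no such limit:
    for (const token of manyTokens)
        pending.push(token);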

src/evaluator/LlamaChatSession/utils/LlamaChatSessionPromptCompletionEngine.ts

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@ import {Token} from "../../../types.js";
 import {getConsoleLogPrefix} from "../../../utils/getConsoleLogPrefix.js";
 import {LruCache} from "../../../utils/LruCache.js";
 import {safeEventCallback} from "../../../utils/safeEventCallback.js";
+import {pushAll} from "../../../utils/pushAll.js";
 import type {LLamaChatCompletePromptOptions, LlamaChatSession} from "../LlamaChatSession.js";

 export type LLamaChatPromptCompletionEngineOptions = {
@@ -146,7 +147,7 @@ export class LlamaChatSessionPromptCompletionEngine
 maxTokens: leftTokens,
 signal: currentAbortSignal,
 onToken: (chunk) => {
-    currentCompletion.push(...chunk);
+    pushAll(currentCompletion, chunk);
     const completion = (existingCompletion ?? "") + this._chatSession.model.detokenize(currentCompletion);
     completionCache.putCompletion(prompt, completion);

src/evaluator/LlamaCompletion.ts

Lines changed: 8 additions & 7 deletions
@@ -9,6 +9,7 @@ import {StopGenerationDetector} from "../utils/StopGenerationDetector.js";
 import {UNKNOWN_UNICODE_CHAR} from "../consts.js";
 import {getQueuedTokensBeforeStopTrigger} from "../utils/getQueuedTokensBeforeStopTrigger.js";
 import {safeEventCallback} from "../utils/safeEventCallback.js";
+import {pushAll} from "../utils/pushAll.js";
 import {LlamaGrammarEvaluationState} from "./LlamaGrammarEvaluationState.js";
 import {LlamaGrammar} from "./LlamaGrammar.js";
 import {EvaluationPriority} from "./LlamaContext/types.js";
@@ -248,7 +249,7 @@ export class LlamaCompletion
 throw new Error("The context size is too small to generate a response for the given input");

 const slicedTokens = tokens.slice(-inputTokensSize);
-res.push(...slicedTokens);
+pushAll(res, slicedTokens);

 return res;
 }
@@ -428,10 +429,10 @@
 newContextState.push(bosToken);

 newContextState.push(prefixToken);
-newContextState.push(...resolvedPrefixTokens);
+pushAll(newContextState, resolvedPrefixTokens);

 newContextState.push(suffixToken);
-newContextState.push(...resolvedSuffixTokens);
+pushAll(newContextState, resolvedSuffixTokens);

 newContextState.push(middleToken);

@@ -655,7 +656,7 @@
 stopGenerationDetector.recordGeneration({text, tokens, queuedTokenRelease});
 customStopGenerationTriggersDetector.recordGeneration({text, tokens, queuedTokenRelease});

-pendingTokens.push(...streamRegulator.popFreeChunkTokens());
+pushAll(pendingTokens, streamRegulator.popFreeChunkTokens());

 if (stopGenerationDetector.hasTriggeredStops || customStopGenerationTriggersDetector.hasTriggeredStops ||
     model.isEogToken(token)
@@ -670,14 +671,14 @@
 partiallyFreeTokens,
 model.tokenizer
 );
-pendingTokens.push(...queuedTokensBeforeStopTrigger);
+pushAll(pendingTokens, queuedTokensBeforeStopTrigger);

 const firstRemainingGenerationAfterStop = StopGenerationDetector.getFirstRemainingGenerationAfterStop(triggeredStops);

 if (pendingTokens.length > 0)
     onToken?.(pendingTokens.slice());

-res.push(...pendingTokens);
+pushAll(res, pendingTokens);
 pendingTokens.length = 0;

 let modelResponse = model.detokenize(res);
@@ -710,7 +711,7 @@

 if (pendingTokens.length > 0) {
     onToken?.(pendingTokens.slice());
-    res.push(...pendingTokens);
+    pushAll(res, pendingTokens);
     pendingTokens.length = 0;
 }

src/utils/LlamaText.ts

Lines changed: 7 additions & 4 deletions
@@ -1,3 +1,4 @@
+import {pushAll} from "./pushAll.js";
 import type {InspectOptions, inspect as InspectFunction} from "node:util";
 import type {Token, Tokenizer} from "../types.js";

@@ -52,7 +53,7 @@ class LlamaText

 if (i !== this.values.length - 1) {
     if (isLlamaText(separator))
-        newValues.push(...separator.values);
+        pushAll(newValues, separator.values);
     else
         newValues.push(separator);
 }
@@ -98,16 +99,18 @@

 for (const value of this.values) {
     if (value instanceof SpecialToken) {
-        res.push(...tokenizer(textToTokenize, false, resolveTokenizerOptions()), ...value.tokenize(tokenizer));
+        pushAll(res, tokenizer(textToTokenize, false, resolveTokenizerOptions()));
+        pushAll(res, value.tokenize(tokenizer));
         textToTokenize = "";
     } else if (value instanceof SpecialTokensText) {
-        res.push(...tokenizer(textToTokenize, false, resolveTokenizerOptions()), ...value.tokenize(tokenizer, hasContent() || options === "trimLeadingSpace"));
+        pushAll(res, tokenizer(textToTokenize, false, resolveTokenizerOptions()));
+        pushAll(res, value.tokenize(tokenizer, hasContent() || options === "trimLeadingSpace"));
         textToTokenize = "";
     } else
         textToTokenize += value;
 }

-res.push(...tokenizer(textToTokenize, false, resolveTokenizerOptions()));
+pushAll(res, tokenizer(textToTokenize, false, resolveTokenizerOptions()));

 return res;
 }
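
This is the code path the commit title points at: LlamaText.tokenize() used to spread the tokens of each text segment (and of the following special token) into a single push() call, so tokenizing a LlamaText containing a very long string could hit the argument limit. Splitting it into two pushAll() calls appends the same tokens in the same order without spreading. A small sketch of the equivalence (token values are made up for illustration):

    const res: number[] = [];
    const textTokens = [101, 102, 103];   // stands in for tokenizer(textToTokenize, false, ...)
    const specialTokens = [1];            // stands in for value.tokenize(tokenizer)

    // Before: res.push(...textTokens, ...specialTokens);
    pushAll(res, textTokens);
    pushAll(res, specialTokens);          // res is now [101, 102, 103, 1]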

src/utils/TokenStreamRegulator.ts

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,6 @@
 import {DisposedError} from "lifecycle-utils";
 import {Token, Tokenizer} from "../types.js";
+import {pushAll} from "./pushAll.js";

 export class TokenStreamRegulator {
     /** @internal */ private readonly _queue: QueuedTokenRelease[] = [];
@@ -16,7 +17,7 @@ export class TokenStreamRegulator {
 const res: Token[] = [];

 while (this._queue.length > 0 && this._queue[0].isFree)
-    res.push(...this._queue.shift()!.tokens);
+    pushAll(res, this._queue.shift()!.tokens);

 return res;
 }
@@ -60,7 +61,7 @@ export class TokenStreamRegulator {
 if (resTokensText.length + tokenText.length > text.length) {
     const remainingText = text.slice(resTokensText.length);
     const remainingTokens = tokenizer(remainingText, false, "trimLeadingSpace");
-    resTokens.push(...remainingTokens);
+    pushAll(resTokens, remainingTokens);
     break;
 }


src/utils/pushAll.ts

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+/**
+ * Pushes all items from the given array or set to the given array.
+ * @param array - The array to push the items to
+ * @param items - The items to push to the array
+ */
+export function pushAll<T>(array: T[], items: readonly T[] | ReadonlySet<T>): T[] {
+    for (const item of items)
+        array.push(item);
+
+    return array;
+}
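
The helper accepts either an array or a set and returns the target array, so it can be dropped in wherever target.push(...items) was used. A short usage sketch (the values and the relative import path are illustrative):

    import {pushAll} from "./pushAll.js";

    const tokens: number[] = [1, 2];
    pushAll(tokens, [3, 4, 5]);              // tokens is now [1, 2, 3, 4, 5]
    pushAll(tokens, new Set([6, 7]));        // ReadonlySet inputs are iterated the same way
    console.log(pushAll(tokens, []).length); // the array itself is returned, so calls can be chained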

templates/electron-typescript-react/electron/state/llmState.ts

Lines changed: 3 additions & 1 deletion
@@ -343,7 +343,9 @@ export const llmFunctions = {
 signal: promptAbortController.signal,
 stopOnAbortSignal: true,
 onToken(chunk) {
-    inProgressResponse.push(...chunk);
+    for (const token of chunk)
+        inProgressResponse.push(token);
+
     llmState.state = {
         ...llmState.state,
         chatSession: {

test/modelDependent/llama3/chatSession.test.ts

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@ import {describe, expect, test} from "vitest";
 import {Llama3ChatWrapper, LlamaChatSession, Token} from "../../../src/index.js";
 import {getModelFile} from "../../utils/modelFiles.js";
 import {getTestLlama} from "../../utils/getTestLlama.js";
+import {pushAll} from "../../../src/utils/pushAll.js";

 describe("llama 3", () => {
     describe("chat session", () => {
@@ -27,7 +28,7 @@
 signal: abortController.signal,
 stopOnAbortSignal: true,
 onToken(chunk) {
-    tokens.push(...chunk);
+    pushAll(tokens, chunk);

     if (tokens.length >= 2)
         abortController.abort();

test/modelDependent/llama3/grammar.test.ts

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@ describe("llama 3", () => {
 contextSequence: context.getSequence()
 });

-const grammar = new LlamaJsonSchemaGrammar(llama, {
+const grammar = await llama.createGrammarForJsonSchema({
     type: "object",
     properties: {
         "userMessagePositivityScoreFromOneToTen": {
@@ -33,7 +33,7 @@ describe("llama 3", () => {
 }
 }
 }
-} as const);
+});

 const res = await chatSession.prompt("It's great!", {
     grammar
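
Unrelated to the pushAll change, this test now obtains its grammar from llama.createGrammarForJsonSchema(...) instead of constructing LlamaJsonSchemaGrammar directly, and drops the "as const" assertion on the schema. A hedged sketch of the updated call shape, using only identifiers visible in this diff (the schema properties omitted from the hunk are left out here as well):

    // Assumes `llama` and `chatSession` are set up as earlier in this test file.
    const grammar = await llama.createGrammarForJsonSchema({
        type: "object",
        properties: {
            // ..."userMessagePositivityScoreFromOneToTen" schema, truncated in the diff...
        }
    });

    const res = await chatSession.prompt("It's great!", {
        grammar
    });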
