Skip to content

Commit cf0d4e6

Browse files
authored
feat(model): support emoji (#1)
1 parent 6dd0d56 commit cf0d4e6

File tree

2 files changed

+50
-35
lines changed

2 files changed

+50
-35
lines changed

src/LlamaChatSession.ts

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {ChatPromptWrapper} from "./ChatPromptWrapper.js";
55
import {LlamaChatPromptWrapper} from "./chatWrappers/LlamaChatPromptWrapper.js";
66
import {AbortError} from "./AbortError.js";
77

8+
const UNKNOWN_UNICODE_CHAR = "�";
89

910
export class LlamaChatSession {
1011
private readonly _model: LlamaModel;
@@ -52,7 +53,7 @@ export class LlamaChatSession {
5253
});
5354
}
5455

55-
public async prompt(prompt: string, onToken?: (token: number) => void, {signal}: {signal?: AbortSignal} = {}) {
56+
public async prompt(prompt: string, onToken?: (tokens: number[]) => void, {signal}: { signal?: AbortSignal } = {}) {
5657
if (!this.initialized)
5758
await this.init();
5859

@@ -64,56 +65,70 @@ export class LlamaChatSession {
6465
});
6566
}
6667

67-
private async _evalTokens(tokens: Uint32Array, onToken?: (token: number) => void, {signal}: {signal?: AbortSignal} = {}) {
68+
private async _evalTokens(tokens: Uint32Array, onToken?: (tokens: number[]) => void, {signal}: { signal?: AbortSignal } = {}) {
69+
const decodeTokens = (tokens: number[]) => this._model.decode(Uint32Array.from(tokens));
70+
6871
const stopStrings = this._promptWrapper.getStopStrings();
6972
const stopStringIndexes = Array(stopStrings.length).fill(0);
7073
const skippedChunksQueue: number[] = [];
71-
let res = "";
74+
const res: number[] = [];
75+
7276

7377
for await (const chunk of this._model.evaluate(tokens)) {
7478
if (signal?.aborted)
7579
throw new AbortError();
7680

77-
const tokenStr = this._model.decode(Uint32Array.from([chunk]));
78-
let skipTokenEvent = false;
79-
80-
for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) {
81-
const stopString = stopStrings[stopStringIndex];
82-
83-
let localShouldSkipTokenEvent = false;
84-
for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
85-
if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) {
86-
stopStringIndexes[stopStringIndex]++;
87-
localShouldSkipTokenEvent = true;
88-
} else {
89-
stopStringIndexes[stopStringIndex] = 0;
90-
localShouldSkipTokenEvent = false;
91-
break;
92-
}
93-
}
81+
const tokenStr = decodeTokens([chunk]);
82+
const {shouldReturn, skipTokenEvent} = this._checkStopString(tokenStr, stopStringIndexes);
83+
84+
if (shouldReturn)
85+
return decodeTokens(res);
9486

95-
if (stopStringIndexes[stopStringIndex] === stopString.length) {
96-
return res;
97-
}
87+
// if the token is unknown, it means it's not a complete character
88+
if (tokenStr === UNKNOWN_UNICODE_CHAR || skipTokenEvent) {
89+
skippedChunksQueue.push(chunk);
90+
continue;
91+
}
9892

99-
skipTokenEvent ||= localShouldSkipTokenEvent;
93+
if (skippedChunksQueue.length > 0) {
94+
res.push(...skippedChunksQueue);
95+
onToken?.(skippedChunksQueue);
96+
skippedChunksQueue.length = 0;
10097
}
10198

102-
if (skipTokenEvent) {
103-
skippedChunksQueue.push(chunk);
104-
continue;
99+
res.push(chunk);
100+
onToken?.([chunk]);
101+
}
102+
103+
return decodeTokens(res);
104+
}
105+
106+
private _checkStopString(tokenStr: string, stopStringIndexes: number[]){
107+
const stopStrings = this._promptWrapper.getStopStrings();
108+
let skipTokenEvent = false;
109+
110+
for (let stopStringIndex = 0; stopStringIndex < stopStrings.length; stopStringIndex++) {
111+
const stopString = stopStrings[stopStringIndex];
112+
113+
let localShouldSkipTokenEvent = false;
114+
for (let i = 0; i < tokenStr.length && stopStringIndexes[stopStringIndex] !== stopString.length; i++) {
115+
if (tokenStr[i] === stopString[stopStringIndexes[stopStringIndex]]) {
116+
stopStringIndexes[stopStringIndex]++;
117+
localShouldSkipTokenEvent = true;
118+
} else {
119+
stopStringIndexes[stopStringIndex] = 0;
120+
localShouldSkipTokenEvent = false;
121+
break;
122+
}
105123
}
106124

107-
while (skippedChunksQueue.length > 0) {
108-
const token = skippedChunksQueue.shift()!;
109-
res += this._model.decode(Uint32Array.from([token]));
110-
onToken?.(token);
125+
if (stopStringIndexes[stopStringIndex] === stopString.length) {
126+
return {shouldReturn: true};
111127
}
112128

113-
res += tokenStr;
114-
onToken?.(chunk);
129+
skipTokenEvent ||= localShouldSkipTokenEvent;
115130
}
116131

117-
return res;
132+
return {skipTokenEvent};
118133
}
119134
}

src/cli/commands/ChatCommand.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ async function RunChat({model: modelArg, systemInfo, systemPrompt}: ChatCommand)
9393

9494
process.stdout.write(startColor);
9595
await session.prompt(input, (chunk) => {
96-
process.stdout.write(model.decode(Uint32Array.from([chunk])));
96+
process.stdout.write(model.decode(Uint32Array.from(chunk)));
9797
});
9898
process.stdout.write(endColor);
9999
console.log();

0 commit comments

Comments
 (0)