Commit ed4deba

Fix/chat context (#3)
1 parent b9a2a61 commit ed4deba

File tree

6 files changed: +65, -41 lines

src/cli/commands/ChatCommand.ts
src/index.ts
src/llamaEvaluator/LlamaBins.ts
src/llamaEvaluator/LlamaChatSession.ts (renamed from src/LlamaChatSession.ts)
src/llamaEvaluator/LlamaContext.ts (renamed from src/LlamaModel.ts)
src/llamaEvaluator/LlamaModel.ts

src/cli/commands/ChatCommand.ts

Lines changed: 4 additions & 4 deletions
@@ -60,14 +60,14 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
 
 
 async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper}: ChatCommand) {
-    const {LlamaChatSession} = await import("../../LlamaChatSession.js");
-    const {LlamaModel} = await import("../../LlamaModel.js");
+    const {LlamaChatSession} = await import("../../llamaEvaluator/LlamaChatSession.js");
+    const {LlamaModel} = await import("../../llamaEvaluator/LlamaModel.js");
 
     const model = new LlamaModel({
         modelPath: modelArg
     });
     const session = new LlamaChatSession({
-        model,
+        context: model.createContext(),
         printLLamaSystemInfo: systemInfo,
         systemPrompt,
         promptWrapper: createChatWrapper(wrapper)
@@ -102,7 +102,7 @@ async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper}: ChatCommand) {
 
     process.stdout.write(startColor);
     await session.prompt(input, (chunk) => {
-        process.stdout.write(model.decode(Uint32Array.from(chunk)));
+        process.stdout.write(session.context.decode(Uint32Array.from(chunk)));
     });
     process.stdout.write(endColor);
     console.log();
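
The upshot for the CLI: where RunChat previously passed the loaded model straight into the session and decoded streamed tokens with model.decode, it now creates a context from the model and lets the session own it. A minimal sketch of the new pattern (imports and the modelArg/input variables as in the diff above; the rest of the CLI plumbing is omitted):

    const model = new LlamaModel({modelPath: modelArg});
    const session = new LlamaChatSession({context: model.createContext()});

    await session.prompt(input, (chunk) => {
        // chunk is an array of token ids; decode it through the session's context
        process.stdout.write(session.context.decode(Uint32Array.from(chunk)));
    });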

src/index.ts

Lines changed: 5 additions & 2 deletions
@@ -1,13 +1,16 @@
-import {LlamaChatSession} from "./LlamaChatSession.js";
-import {LlamaModel} from "./LlamaModel.js";
+import {LlamaChatSession} from "./llamaEvaluator/LlamaChatSession.js";
+import {LlamaModel} from "./llamaEvaluator/LlamaModel.js";
 import {AbortError} from "./AbortError.js";
 import {ChatPromptWrapper} from "./ChatPromptWrapper.js";
 import {EmptyChatPromptWrapper} from "./chatWrappers/EmptyChatPromptWrapper.js";
 import {LlamaChatPromptWrapper} from "./chatWrappers/LlamaChatPromptWrapper.js";
 import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper.js";
+import {LlamaContext} from "./llamaEvaluator/LlamaContext.js";
+
 
 export {
     LlamaModel,
+    LlamaContext,
     LlamaChatSession,
     AbortError,
     ChatPromptWrapper,

src/llamaEvaluator/LlamaBins.ts

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+import {loadBin, type LLAMAModel, type LLAMAContext} from "../utils/getBin.js";
+
+export const llamaCppNode = await loadBin();
+const {LLAMAModel, LLAMAContext} = llamaCppNode;
+
+export {LLAMAModel, LLAMAContext};
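
This new module appears to centralize the single top-level await loadBin() call, so the native llama.cpp bindings are loaded once at startup and every evaluator module imports the already-initialized objects instead of loading them itself. A tiny sketch of a consumer, mirroring the imports used by LlamaContext.ts and LlamaModel.ts below:

    import {llamaCppNode} from "./LlamaBins.js";

    // no await needed here: LlamaBins.ts already resolved loadBin() via top-level await
    console.log(llamaCppNode.systemInfo());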

src/LlamaChatSession.ts renamed to src/llamaEvaluator/LlamaChatSession.ts

Lines changed: 16 additions & 15 deletions
@@ -1,32 +1,33 @@
-import {defaultChatSystemPrompt} from "./config.js";
-import {withLock} from "./utils/withLock.js";
+import {defaultChatSystemPrompt} from "../config.js";
+import {withLock} from "../utils/withLock.js";
+import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
+import {AbortError} from "../AbortError.js";
+import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper.js";
 import {LlamaModel} from "./LlamaModel.js";
-import {ChatPromptWrapper} from "./ChatPromptWrapper.js";
-import {AbortError} from "./AbortError.js";
-import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper.js";
+import {LlamaContext} from "./LlamaContext.js";
 
 const UNKNOWN_UNICODE_CHAR = "�";
 
 export class LlamaChatSession {
-    private readonly _model: LlamaModel;
     private readonly _systemPrompt: string;
     private readonly _printLLamaSystemInfo: boolean;
     private readonly _promptWrapper: ChatPromptWrapper;
     private _promptIndex: number = 0;
     private _initialized: boolean = false;
+    private readonly _ctx: LlamaContext;
 
     public constructor({
-        model,
+        context,
         printLLamaSystemInfo = false,
         promptWrapper = new GeneralChatPromptWrapper(),
         systemPrompt = defaultChatSystemPrompt
     }: {
-        model: LlamaModel,
+        context: LlamaContext,
         printLLamaSystemInfo?: boolean,
         promptWrapper?: ChatPromptWrapper,
         systemPrompt?: string,
     }) {
-        this._model = model;
+        this._ctx = context;
         this._printLLamaSystemInfo = printLLamaSystemInfo;
         this._promptWrapper = promptWrapper;
 
@@ -37,8 +38,8 @@ export class LlamaChatSession {
         return this._initialized;
     }
 
-    public get model() {
-        return this._model;
+    public get context() {
+        return this._ctx;
     }
 
     public async init() {
@@ -47,7 +48,7 @@ export class LlamaChatSession {
                 return;
 
             if (this._printLLamaSystemInfo)
-                console.log("Llama system info", this._model.systemInfo);
+                console.log("Llama system info", LlamaModel.systemInfo);
 
             this._initialized = true;
         });
@@ -61,20 +62,20 @@ export class LlamaChatSession {
             const promptText = this._promptWrapper.wrapPrompt(prompt, {systemPrompt: this._systemPrompt, promptIndex: this._promptIndex});
             this._promptIndex++;
 
-            return await this._evalTokens(this._model.encode(promptText), onToken, {signal});
+            return await this._evalTokens(this._ctx.encode(promptText), onToken, {signal});
         });
     }
 
     private async _evalTokens(tokens: Uint32Array, onToken?: (tokens: number[]) => void, {signal}: { signal?: AbortSignal } = {}) {
-        const decodeTokens = (tokens: number[]) => this._model.decode(Uint32Array.from(tokens));
+        const decodeTokens = (tokens: number[]) => this._ctx.decode(Uint32Array.from(tokens));
 
         const stopStrings = this._promptWrapper.getStopStrings();
         const stopStringIndexes = Array(stopStrings.length).fill(0);
         const skippedChunksQueue: number[] = [];
         const res: number[] = [];
 
 
-        for await (const chunk of this._model.evaluate(tokens)) {
+        for await (const chunk of this._ctx.evaluate(tokens)) {
             if (signal?.aborted)
                 throw new AbortError();
 
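
A side effect of moving the session's dependency from LlamaModel to LlamaContext is that, in principle, one loaded model can back several independent sessions, each with its own context. Nothing in this commit exercises that, so the following is only a speculative sketch (import paths as seen from src/, the model path is illustrative, and it assumes prompt() can be called without the streaming callback):

    import {LlamaModel} from "./llamaEvaluator/LlamaModel.js";
    import {LlamaChatSession} from "./llamaEvaluator/LlamaChatSession.js";

    const model = new LlamaModel({modelPath: "path/to/model.bin"});   // illustrative path

    // each session gets its own context, so their conversation state stays separate
    const sessionA = new LlamaChatSession({context: model.createContext()});
    const sessionB = new LlamaChatSession({context: model.createContext()});

    await sessionA.prompt("Hi there!");
    await sessionB.prompt("Hello!");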

src/LlamaModel.ts renamed to src/llamaEvaluator/LlamaContext.ts

Lines changed: 10 additions & 20 deletions

@@ -1,27 +1,16 @@
-import {loadBin, type LLAMAModel, type LLAMAContext} from "./utils/getBin.js";
+import {LLAMAContext, llamaCppNode} from "./LlamaBins.js";
 
-const llamaCppNode = await loadBin();
-const {LLAMAModel, LLAMAContext} = llamaCppNode;
-
-export class LlamaModel {
-    private readonly _model: LLAMAModel;
+type LlamaContextConstructorParameters = {prependBos: boolean, ctx: LLAMAContext};
+export class LlamaContext {
     private readonly _ctx: LLAMAContext;
     private _prependBos: boolean;
 
-    public constructor({
-        modelPath, prependBos = true
-    }: {
-        modelPath: string, prependBos?: boolean
-    }) {
-        this._model = new LLAMAModel(modelPath);
-        this._ctx = new LLAMAContext(this._model);
+    /** @internal */
+    public constructor( {ctx, prependBos}: LlamaContextConstructorParameters ) {
+        this._ctx = ctx;
         this._prependBos = prependBos;
     }
 
-    public get systemInfo() {
-        return llamaCppNode.systemInfo();
-    }
-
     public encode(text: string): Uint32Array {
         return this._ctx.encode(text);
     }
@@ -34,10 +23,10 @@
         let evalTokens = tokens;
 
         if (this._prependBos) {
-            const tokensArray = Array.from(tokens);
-            tokensArray.unshift(llamaCppNode.tokenBos());
+            const tokenArray = Array.from(tokens);
+            tokenArray.unshift(llamaCppNode.tokenBos());
 
-            evalTokens = Uint32Array.from(tokensArray);
+            evalTokens = Uint32Array.from(tokenArray);
             this._prependBos = false;
         }
 
@@ -56,4 +45,5 @@
             evalTokens = Uint32Array.from([nextToken]);
         }
     }
+
 }
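
Tokenization now lives on the context rather than the model, and the constructor is tagged @internal, so the intended way to obtain a LlamaContext is through LlamaModel.createContext() (added in the next file). A small sketch of the encode/decode surface, assuming a model has been loaded as in the earlier sketch:

    const context = model.createContext();

    const tokens = context.encode("Hello");   // Uint32Array of token ids
    const text = context.decode(tokens);      // back to a string

    // evaluate() prepends the BOS token only on its first use; _prependBos is
    // cleared after that, as shown in the diff above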

src/llamaEvaluator/LlamaModel.ts

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import {LlamaContext} from "./LlamaContext.js";
+import {LLAMAContext, llamaCppNode, LLAMAModel} from "./LlamaBins.js";
+
+
+export class LlamaModel {
+    private readonly _model: LLAMAModel;
+    private readonly _prependBos: boolean;
+
+    public constructor({modelPath, prependBos = true}: { modelPath: string, prependBos?: boolean }) {
+        this._model = new LLAMAModel(modelPath);
+        this._prependBos = prependBos;
+    }
+
+    public createContext() {
+        return new LlamaContext({
+            ctx: new LLAMAContext(this._model),
+            prependBos: this._prependBos
+        });
+    }
+
+    public static get systemInfo() {
+        return llamaCppNode.systemInfo();
+    }
+}
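
Note that systemInfo moves from an instance getter on the old LlamaModel to a static getter here, which is why LlamaChatSession.init() above logs LlamaModel.systemInfo instead of this._model.systemInfo; since it just forwards to llamaCppNode.systemInfo(), it can be read before any model is loaded. For example (assuming LlamaModel is imported as in src/index.ts):

    console.log("Llama system info", LlamaModel.systemInfo);   // works before any model is constructed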
