Skip to content

Commit 4e274ce

Browse files
authored
feat: load conversation history into a LlamaChatSession (#51)
1 parent 9c8c42b commit 4e274ce

10 files changed

+177
-13
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,34 @@ const a2 = await session.prompt(q2);
104104
console.log("AI: " + a2);
105105
```
106106

107+
##### Load existing conversation history
108+
```typescript
109+
import {fileURLToPath} from "url";
110+
import path from "path";
111+
import {LlamaModel, LlamaContext, LlamaChatSession} from "node-llama-cpp";
112+
113+
const __dirname = path.dirname(fileURLToPath(import.meta.url));
114+
115+
const model = new LlamaModel({
116+
modelPath: path.join(__dirname, "models", "codellama-13b.Q3_K_M.gguf")
117+
});
118+
const context = new LlamaContext({model});
119+
const session = new LlamaChatSession({
120+
context,
121+
conversationHistory: [{
122+
prompt: `Remember the number 6 as "The number"`,
123+
response: "OK. I'll remember it"
124+
}]
125+
});
126+
127+
128+
const q2 = 'What is "The number"?';
129+
console.log("User: " + q2);
130+
131+
const a2 = await session.prompt(q2);
132+
console.log("AI: " + a2);
133+
```
134+
107135
#### Raw
108136
```typescript
109137
import {fileURLToPath} from "url";

src/ChatPromptWrapper.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,13 @@ export abstract class ChatPromptWrapper {
1414
/**
 * Strings that should terminate model generation for this wrapper.
 * The base implementation defines none; subclasses override this to
 * supply wrapper-specific stop strings.
 */
public getStopStrings(): string[] {
    const stopStrings: string[] = [];
    return stopStrings;
}
17+
18+
/**
 * The stop string appended after a model response by default.
 * Falls back to the first entry of `getStopStrings()`.
 * @throws when the wrapper defines no non-empty stop string
 */
public getDefaultStopString(): string {
    const [firstStopString] = this.getStopStrings();

    if (firstStopString == null || firstStopString.length === 0)
        throw new Error(`Prompt wrapper "${this.wrapperName}" has no stop strings`);

    return firstStopString;
}
1726
}

src/chatWrappers/ChatMLPromptWrapper.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,8 @@ export class ChatMLPromptWrapper extends ChatPromptWrapper {
2121
/** ChatML terminates every message with the `<|im_end|>` marker. */
public override getStopStrings(): string[] {
    const chatMLMessageEnd = "<|im_end|>";
    return [chatMLMessageEnd];
}
24+
25+
/** The ChatML message-end marker is also the default stop string. */
public override getDefaultStopString(): string {
    const chatMLMessageEnd = "<|im_end|>";
    return chatMLMessageEnd;
}
2428
}

src/chatWrappers/GeneralChatPromptWrapper.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,15 @@ export class GeneralChatPromptWrapper extends ChatPromptWrapper {
3232
];
3333
}
3434

35+
/**
 * The default stop string is the beginning of the next instruction
 * header (`\n\n### <instruction name>`), as emitted by this wrapper's
 * prompt format.
 */
public override getDefaultStopString(): string {
    const nextInstructionHeader = `\n\n### ${this._instructionName}`;
    return nextInstructionHeader;
}
38+
3539
private _getPromptPrefix(lastStopString: string | null, lastStopStringSuffix: string | null) {
3640
return getTextCompletion(
3741
lastStopString === "<end>"
3842
? lastStopStringSuffix
39-
: (lastStopString + (lastStopStringSuffix ?? "")),
43+
: ((lastStopString ?? "") + (lastStopStringSuffix ?? "")),
4044
[
4145
`\n\n### ${this._instructionName}:\n\n`,
4246
`### ${this._instructionName}:\n\n`

src/chatWrappers/LlamaChatPromptWrapper.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,8 @@ export class LlamaChatPromptWrapper extends ChatPromptWrapper {
2121
/** Llama chat responses end with the `</s>` end-of-sequence token. */
public override getStopStrings(): string[] {
    const endOfSequence = "</s>";
    return [endOfSequence];
}
24+
25+
/** The `</s>` end-of-sequence token is also the default stop string. */
public override getDefaultStopString(): string {
    const endOfSequence = "</s>";
    return endOfSequence;
}
2428
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
2+
import {defaultChatSystemPrompt} from "../config.js";
3+
import {ConversationInteraction} from "../types.js";
4+
5+
6+
/**
7+
* Generate context text to load into a model context from a conversation history.
8+
* @param {ChatPromptWrapper} chatPromptWrapper
9+
* @param {ConversationInteraction[]} conversationHistory
10+
* @param {object} [options]
11+
* @param {string} [options.systemPrompt]
12+
* @param {number} [options.currentPromptIndex]
13+
* @param {string | null} [options.lastStopString]
14+
* @param {string | null} [options.lastStopStringSuffix]
15+
* @returns {{text: string, stopString: (string | null), stopStringSuffix: (string | null)}}
16+
*/
17+
export function generateContextTextFromConversationHistory(
18+
chatPromptWrapper: ChatPromptWrapper,
19+
conversationHistory: readonly ConversationInteraction[],
20+
{
21+
systemPrompt = defaultChatSystemPrompt, currentPromptIndex = 0, lastStopString = null, lastStopStringSuffix = null
22+
}: {
23+
systemPrompt?: string, currentPromptIndex?: number, lastStopString?: string | null, lastStopStringSuffix?: string | null
24+
} = {}
25+
): {
26+
text: string;
27+
stopString: string | null;
28+
stopStringSuffix: string | null;
29+
} {
30+
let res = "";
31+
32+
for (let i = 0; i < conversationHistory.length; i++) {
33+
const interaction = conversationHistory[i];
34+
const wrappedPrompt = chatPromptWrapper.wrapPrompt(interaction.prompt, {
35+
systemPrompt,
36+
promptIndex: currentPromptIndex,
37+
lastStopString,
38+
lastStopStringSuffix
39+
});
40+
const stopStrings = chatPromptWrapper.getStopStrings();
41+
const defaultStopString = chatPromptWrapper.getDefaultStopString();
42+
const stopStringsToCheckInResponse = new Set([...stopStrings, defaultStopString]);
43+
44+
currentPromptIndex++;
45+
lastStopString = null;
46+
lastStopStringSuffix = null;
47+
48+
res += wrappedPrompt;
49+
50+
for (const stopString of stopStringsToCheckInResponse) {
51+
if (interaction.response.includes(stopString)) {
52+
console.error(
53+
`Stop string "${stopString}" was found in model response of conversation interaction index ${i}`,
54+
{interaction, stopString}
55+
);
56+
throw new Error("A stop string cannot be in a conversation history interaction model response");
57+
}
58+
}
59+
60+
res += interaction.response;
61+
res += defaultStopString;
62+
lastStopString = defaultStopString;
63+
lastStopStringSuffix = "";
64+
}
65+
66+
return {
67+
text: res,
68+
stopString: lastStopString,
69+
stopStringSuffix: lastStopStringSuffix
70+
};
71+
}

src/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper.
1010
import {ChatMLPromptWrapper} from "./chatWrappers/ChatMLPromptWrapper.js";
1111
import {getChatWrapperByBos} from "./chatWrappers/createChatWrapperByBos.js";
1212

13-
import {type Token} from "./types.js";
13+
import {type ConversationInteraction, type Token} from "./types.js";
1414

1515

1616
export {
@@ -22,6 +22,7 @@ export {
2222
type LlamaContextOptions,
2323
LlamaChatSession,
2424
type LlamaChatSessionOptions,
25+
type ConversationInteraction,
2526
AbortError,
2627
ChatPromptWrapper,
2728
EmptyChatPromptWrapper,

src/llamaEvaluator/LlamaChatSession.ts

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
44
import {AbortError} from "../AbortError.js";
55
import {GeneralChatPromptWrapper} from "../chatWrappers/GeneralChatPromptWrapper.js";
66
import {getChatWrapperByBos} from "../chatWrappers/createChatWrapperByBos.js";
7-
import {Token} from "../types.js";
7+
import {ConversationInteraction, Token} from "../types.js";
8+
import {generateContextTextFromConversationHistory} from "../chatWrappers/generateContextTextFromConversationHistory.js";
89
import {LlamaModel} from "./LlamaModel.js";
910
import {LlamaContext} from "./LlamaContext.js";
1011

@@ -15,7 +16,10 @@ export type LlamaChatSessionOptions = {
1516
context: LlamaContext,
1617
printLLamaSystemInfo?: boolean,
1718
promptWrapper?: ChatPromptWrapper | "auto",
18-
systemPrompt?: string
19+
systemPrompt?: string,
20+
21+
/** Conversation history to load into the context to continue an existing conversation */
22+
conversationHistory?: readonly ConversationInteraction[]
1923
};
2024

2125
export class LlamaChatSession {
@@ -26,17 +30,22 @@ export class LlamaChatSession {
2630
private _initialized: boolean = false;
2731
private _lastStopString: string | null = null;
2832
private _lastStopStringSuffix: string | null = null;
33+
private _conversationHistoryToLoad: readonly ConversationInteraction[] | null = null;
2934
private readonly _ctx: LlamaContext;
3035

3136
public constructor({
3237
context,
3338
printLLamaSystemInfo = false,
3439
promptWrapper = new GeneralChatPromptWrapper(),
35-
systemPrompt = defaultChatSystemPrompt
40+
systemPrompt = defaultChatSystemPrompt,
41+
conversationHistory
3642
}: LlamaChatSessionOptions) {
3743
this._ctx = context;
3844
this._printLLamaSystemInfo = printLLamaSystemInfo;
3945
this._systemPrompt = systemPrompt;
46+
this._conversationHistoryToLoad = (conversationHistory != null && conversationHistory.length > 0)
47+
? conversationHistory
48+
: null;
4049

4150
if (promptWrapper === "auto") {
4251
const chatWrapper = getChatWrapperByBos(context.getBosString());
@@ -76,7 +85,32 @@ export class LlamaChatSession {
7685
await this.init();
7786

7887
return await withLock(this, "prompt", async () => {
79-
const promptText = this._promptWrapper.wrapPrompt(prompt, {
88+
let promptText = "";
89+
90+
if (this._promptIndex == 0 && this._conversationHistoryToLoad != null) {
91+
const {text, stopString, stopStringSuffix} =
92+
generateContextTextFromConversationHistory(this._promptWrapper, this._conversationHistoryToLoad, {
93+
systemPrompt: this._systemPrompt,
94+
currentPromptIndex: this._promptIndex,
95+
lastStopString: this._lastStopString,
96+
lastStopStringSuffix: this._promptIndex == 0
97+
? (
98+
this._ctx.prependBos
99+
? this._ctx.getBosString()
100+
: null
101+
)
102+
: this._lastStopStringSuffix
103+
});
104+
105+
promptText += text;
106+
this._lastStopString = stopString;
107+
this._lastStopStringSuffix = stopStringSuffix;
108+
this._promptIndex += this._conversationHistoryToLoad.length;
109+
110+
this._conversationHistoryToLoad = null;
111+
}
112+
113+
promptText += this._promptWrapper.wrapPrompt(prompt, {
80114
systemPrompt: this._systemPrompt,
81115
promptIndex: this._promptIndex,
82116
lastStopString: this._lastStopString,

src/llamaEvaluator/LlamaContext.ts

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,19 @@ export type LlamaContextOptions = {
1313

1414
export class LlamaContext {
1515
private readonly _ctx: LLAMAContext;
16-
private _prependBos: boolean;
16+
private readonly _prependBos: boolean;
17+
private _prependTokens: Token[];
1718

1819
/**
 * Create a context for the given model.
 * @param options - `model` to bind the context to, an optional `grammar`,
 * and whether to prepend a BOS token before the first evaluation.
 */
public constructor({model, grammar, prependBos = true}: LlamaContextOptions) {
    this._ctx = new LLAMAContext(model._model, removeNullFields({
        grammar: grammar?._grammar
    }));
    this._prependBos = prependBos;

    // Queue the BOS token so it is evaluated together with the
    // first batch of tokens passed to evaluate().
    this._prependTokens = prependBos
        ? [this._ctx.tokenBos()]
        : [];
}
2430

2531
public encode(text: string): Uint32Array {
@@ -115,19 +121,18 @@ export class LlamaContext {
115121
return this._ctx.getTokenString(nlToken);
116122
}
117123

118-
/** The context size, as reported by the underlying llama context. */
public getContextSize(): number {
    const contextSize = this._ctx.getContextSize();
    return contextSize;
}
121127

122128
public async *evaluate(tokens: Uint32Array): AsyncGenerator<Token, void> {
123129
let evalTokens = tokens;
124130

125-
if (this._prependBos) {
126-
const tokenArray: Token[] = Array.from(tokens);
127-
tokenArray.unshift(this._ctx.tokenBos());
131+
if (this._prependTokens.length > 0) {
132+
const tokenArray: Token[] = this._prependTokens.concat(Array.from(tokens));
128133

129134
evalTokens = Uint32Array.from(tokenArray);
130-
this._prependBos = false;
135+
this._prependTokens = [];
131136
}
132137

133138
// eslint-disable-next-line no-constant-condition
@@ -145,5 +150,4 @@ export class LlamaContext {
145150
evalTokens = Uint32Array.from([nextToken]);
146151
}
147152
}
148-
149153
}

src/types.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
11
export type Token = number;
2+
3+
export type ConversationInteraction = {
4+
prompt: string,
5+
response: string
6+
};

0 commit comments

Comments
 (0)