Skip to content

Commit b9a2a61

Browse files
authored
feat: general chat wrapper (#2)
1 parent cf0d4e6 commit b9a2a61

File tree

4 files changed

+43
-9
lines changed

4 files changed

+43
-9
lines changed

src/LlamaChatSession.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ import {defaultChatSystemPrompt} from "./config.js";
22
import {withLock} from "./utils/withLock.js";
33
import {LlamaModel} from "./LlamaModel.js";
44
import {ChatPromptWrapper} from "./ChatPromptWrapper.js";
5-
import {LlamaChatPromptWrapper} from "./chatWrappers/LlamaChatPromptWrapper.js";
65
import {AbortError} from "./AbortError.js";
6+
import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper.js";
77

88
const UNKNOWN_UNICODE_CHAR = "�";
99

@@ -18,7 +18,7 @@ export class LlamaChatSession {
1818
public constructor({
1919
model,
2020
printLLamaSystemInfo = false,
21-
promptWrapper = new LlamaChatPromptWrapper(),
21+
promptWrapper = new GeneralChatPromptWrapper(),
2222
systemPrompt = defaultChatSystemPrompt
2323
}: {
2424
model: LlamaModel,
@@ -80,7 +80,7 @@ export class LlamaChatSession {
8080

8181
const tokenStr = decodeTokens([chunk]);
8282
const {shouldReturn, skipTokenEvent} = this._checkStopString(tokenStr, stopStringIndexes);
83-
83+
8484
if (shouldReturn)
8585
return decodeTokens(res);
8686

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import {ChatPromptWrapper} from "../ChatPromptWrapper.js";
2+
3+
export class GeneralChatPromptWrapper extends ChatPromptWrapper {
4+
public override wrapPrompt(prompt: string, {systemPrompt, promptIndex}: { systemPrompt: string, promptIndex: number }) {
5+
const conversationPrompt = "\n\n### Human:\n\n" + prompt + "\n\n### Assistant:\n\n";
6+
7+
return promptIndex === 0 ? systemPrompt + conversationPrompt : conversationPrompt;
8+
}
9+
10+
public override getStopStrings(): string[] {
11+
return ["### Human:", "Human:", "### Assistant:", "Assistant:", "<end>"];
12+
}
13+
}

src/cli/commands/ChatCommand.ts

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import chalk from "chalk";
55
import withOra from "../../utils/withOra.js";
66
import {defaultChatSystemPrompt} from "../../config.js";
77
import {LlamaChatPromptWrapper} from "../../chatWrappers/LlamaChatPromptWrapper.js";
8+
import {GeneralChatPromptWrapper} from "../../chatWrappers/GeneralChatPromptWrapper.js";
89

910
type ChatCommand = {
1011
model: string,
1112
systemInfo: boolean,
12-
systemPrompt: string
13+
systemPrompt: string,
14+
wrapper: string
1315
};
1416

1517
export const ChatCommand: CommandModule<object, ChatCommand> = {
@@ -37,11 +39,18 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
3739
"System prompt to use against the model. " +
3840
"[default value: " + defaultChatSystemPrompt.split("\n").join(" ") + "]",
3941
group: "Optional:"
42+
})
43+
.option("wrapper", {
44+
type: "string",
45+
default: "general",
46+
choices: ["general", "llama"],
47+
description: "Chat wrapper to use",
48+
group: "Optional:"
4049
});
4150
},
42-
async handler({model, systemInfo, systemPrompt}) {
51+
async handler({model, systemInfo, systemPrompt, wrapper}) {
4352
try {
44-
await RunChat({model, systemInfo, systemPrompt});
53+
await RunChat({model, systemInfo, systemPrompt, wrapper});
4554
} catch (err) {
4655
console.error(err);
4756
process.exit(1);
@@ -50,7 +59,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
5059
};
5160

5261

53-
async function RunChat({model: modelArg, systemInfo, systemPrompt}: ChatCommand) {
62+
async function RunChat({model: modelArg, systemInfo, systemPrompt, wrapper}: ChatCommand) {
5463
const {LlamaChatSession} = await import("../../LlamaChatSession.js");
5564
const {LlamaModel} = await import("../../LlamaModel.js");
5665

@@ -61,7 +70,7 @@ async function RunChat({model: modelArg, systemInfo, systemPrompt}: ChatCommand)
6170
model,
6271
printLLamaSystemInfo: systemInfo,
6372
systemPrompt,
64-
promptWrapper: new LlamaChatPromptWrapper()
73+
promptWrapper: createChatWrapper(wrapper)
6574
});
6675

6776
await withOra({
@@ -99,3 +108,13 @@ async function RunChat({model: modelArg, systemInfo, systemPrompt}: ChatCommand)
99108
console.log();
100109
}
101110
}
111+
112+
function createChatWrapper(wrapper: string) {
113+
switch (wrapper) {
114+
case "general":
115+
return new GeneralChatPromptWrapper();
116+
case "llama":
117+
return new LlamaChatPromptWrapper();
118+
}
119+
throw new Error("Unknown wrapper: " + wrapper);
120+
}

src/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ import {AbortError} from "./AbortError.js";
44
import {ChatPromptWrapper} from "./ChatPromptWrapper.js";
55
import {EmptyChatPromptWrapper} from "./chatWrappers/EmptyChatPromptWrapper.js";
66
import {LlamaChatPromptWrapper} from "./chatWrappers/LlamaChatPromptWrapper.js";
7+
import {GeneralChatPromptWrapper} from "./chatWrappers/GeneralChatPromptWrapper.js";
78

89
export {
910
LlamaModel,
1011
LlamaChatSession,
1112
AbortError,
1213
ChatPromptWrapper,
1314
EmptyChatPromptWrapper,
14-
LlamaChatPromptWrapper
15+
LlamaChatPromptWrapper,
16+
GeneralChatPromptWrapper
1517
};

0 commit comments

Comments
 (0)