Skip to content

Commit 624fa30

Browse files
authored
feat: export resolveChatWrapperBasedOnWrapperTypeName (#165)
1 parent ede69c1 commit 624fa30

File tree

3 files changed

+80
-57
lines changed

3 files changed

+80
-57
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import {ModelTypeDescription} from "../AddonTypes.js";
2+
import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js";
3+
import {LlamaChatWrapper} from "../../chatWrappers/LlamaChatWrapper.js";
4+
import {AlpacaChatWrapper} from "../../chatWrappers/AlpacaChatWrapper.js";
5+
import {FunctionaryChatWrapper} from "../../chatWrappers/FunctionaryChatWrapper.js";
6+
import {ChatMLChatWrapper} from "../../chatWrappers/ChatMLChatWrapper.js";
7+
import {FalconChatWrapper} from "../../chatWrappers/FalconChatWrapper.js";
8+
import {resolveChatWrapperBasedOnModel} from "../../chatWrappers/resolveChatWrapperBasedOnModel.js";
9+
10+
export const chatWrapperTypeNames = [
11+
"auto", "general", "llamaChat", "alpacaChat", "functionary", "chatML", "falconChat"
12+
] as const;
13+
export type ChatWrapperTypeName = (typeof chatWrapperTypeNames)[number];
14+
15+
const chatWrappers = {
16+
"general": GeneralChatWrapper,
17+
"llamaChat": LlamaChatWrapper,
18+
"alpacaChat": AlpacaChatWrapper,
19+
"functionary": FunctionaryChatWrapper,
20+
"chatML": ChatMLChatWrapper,
21+
"falconChat": FalconChatWrapper
22+
} as const satisfies Record<Exclude<ChatWrapperTypeName, "auto">, any>;
23+
const chatWrapperToConfigType = new Map(
24+
Object.entries(chatWrappers).map(([configType, Wrapper]) => [Wrapper, configType])
25+
);
26+
27+
export function resolveChatWrapperBasedOnWrapperTypeName(configType: ChatWrapperTypeName, {
28+
bosString,
29+
filename,
30+
typeDescription,
31+
customWrapperSettings
32+
}: {
33+
bosString?: string | null,
34+
filename?: string,
35+
typeDescription?: ModelTypeDescription,
36+
customWrapperSettings?: {
37+
[wrapper in keyof typeof chatWrappers]?: ConstructorParameters<(typeof chatWrappers)[wrapper]>[0]
38+
}
39+
} = {}) {
40+
if (Object.hasOwn(chatWrappers, configType)) {
41+
const Wrapper = chatWrappers[configType as keyof typeof chatWrappers];
42+
const wrapperSettings: ConstructorParameters<typeof Wrapper>[0] | undefined =
43+
customWrapperSettings?.[configType as keyof typeof chatWrappers];
44+
45+
return new Wrapper(wrapperSettings);
46+
}
47+
48+
if (configType === "auto") {
49+
const chatWrapper = resolveChatWrapperBasedOnModel({
50+
bosString,
51+
filename,
52+
typeDescription
53+
});
54+
55+
if (chatWrapper != null) {
56+
const resolvedConfigType = chatWrapperToConfigType.get(chatWrapper);
57+
const wrapperSettings: ConstructorParameters<typeof chatWrapper>[0] | undefined = resolvedConfigType == null
58+
? undefined
59+
: customWrapperSettings?.[resolvedConfigType as keyof typeof chatWrappers];
60+
61+
return new chatWrapper(wrapperSettings);
62+
}
63+
64+
return new GeneralChatWrapper(customWrapperSettings?.general);
65+
}
66+
67+
throw new Error("Unknown wrapper config: " + configType);
68+
}

src/cli/commands/ChatCommand.ts

Lines changed: 6 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,20 @@ import {CommandModule} from "yargs";
55
import chalk from "chalk";
66
import fs from "fs-extra";
77
import {chatCommandHistoryFilePath, defaultChatSystemPrompt} from "../../config.js";
8-
import {LlamaChatWrapper} from "../../chatWrappers/LlamaChatWrapper.js";
9-
import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js";
10-
import {ChatMLChatWrapper} from "../../chatWrappers/ChatMLChatWrapper.js";
11-
import {resolveChatWrapperBasedOnModel} from "../../chatWrappers/resolveChatWrapperBasedOnModel.js";
12-
import {ChatWrapper} from "../../ChatWrapper.js";
13-
import {FalconChatWrapper} from "../../chatWrappers/FalconChatWrapper.js";
148
import {getIsInDocumentationMode} from "../../state.js";
159
import {ReplHistory} from "../../utils/ReplHistory.js";
1610
import withStatusLogs from "../../utils/withStatusLogs.js";
17-
import {AlpacaChatWrapper} from "../../chatWrappers/AlpacaChatWrapper.js";
18-
import {FunctionaryChatWrapper} from "../../chatWrappers/FunctionaryChatWrapper.js";
1911
import {defineChatSessionFunction} from "../../evaluator/LlamaChatSession/utils/defineChatSessionFunction.js";
2012
import {getLlama} from "../../bindings/getLlama.js";
2113
import {LlamaGrammar} from "../../evaluator/LlamaGrammar.js";
22-
import {ModelTypeDescription} from "../../bindings/AddonTypes.js";
2314
import {LlamaChatSession} from "../../evaluator/LlamaChatSession/LlamaChatSession.js";
2415
import {LlamaModel} from "../../evaluator/LlamaModel.js";
2516
import {LlamaContext} from "../../evaluator/LlamaContext/LlamaContext.js";
2617
import {LlamaJsonSchemaGrammar} from "../../evaluator/LlamaJsonSchemaGrammar.js";
2718
import {LlamaLogLevel} from "../../bindings/types.js";
28-
29-
const modelWrappers = ["auto", "general", "llamaChat", "alpacaChat", "functionary", "chatML", "falconChat"] as const;
19+
import {
20+
resolveChatWrapperBasedOnWrapperTypeName, chatWrapperTypeNames, ChatWrapperTypeName
21+
} from "../../bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.js";
3022

3123
type ChatCommand = {
3224
model: string,
@@ -35,7 +27,7 @@ type ChatCommand = {
3527
systemPromptFile?: string,
3628
prompt?: string,
3729
promptFile?: string,
38-
wrapper: (typeof modelWrappers)[number],
30+
wrapper: ChatWrapperTypeName,
3931
contextSize: number,
4032
batchSize?: number,
4133
grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
@@ -108,7 +100,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
108100
alias: "w",
109101
type: "string",
110102
default: "auto" as ChatCommand["wrapper"],
111-
choices: modelWrappers,
103+
choices: chatWrapperTypeNames,
112104
description: "Chat wrapper to use. Use `auto` to automatically select a wrapper based on the model's BOS token",
113105
group: "Optional:"
114106
})
@@ -343,7 +335,7 @@ async function RunChat({
343335
: undefined;
344336
const bos = model.tokens.bosString; // bos = beginning of sequence
345337
const eos = model.tokens.bosString; // eos = end of sequence
346-
const chatWrapper = getChatWrapper(wrapper, {
338+
const chatWrapper = resolveChatWrapperBasedOnWrapperTypeName(wrapper, {
347339
bosString: bos,
348340
filename: model.filename,
349341
typeDescription: model.typeDescription
@@ -477,46 +469,3 @@ const defaultEnvironmentFunctions = {
477469
}
478470
})
479471
};
480-
481-
function getChatWrapper(wrapper: ChatCommand["wrapper"], {
482-
bosString,
483-
filename,
484-
typeDescription
485-
}: {
486-
bosString?: string | null,
487-
filename?: string,
488-
typeDescription?: ModelTypeDescription
489-
}): ChatWrapper {
490-
switch (wrapper) {
491-
case "general":
492-
return new GeneralChatWrapper();
493-
case "llamaChat":
494-
return new LlamaChatWrapper();
495-
case "alpacaChat":
496-
return new AlpacaChatWrapper();
497-
case "functionary":
498-
return new FunctionaryChatWrapper();
499-
case "chatML":
500-
return new ChatMLChatWrapper();
501-
case "falconChat":
502-
return new FalconChatWrapper();
503-
default:
504-
}
505-
506-
if (wrapper === "auto") {
507-
const chatWrapper = resolveChatWrapperBasedOnModel({
508-
bosString,
509-
filename,
510-
typeDescription
511-
});
512-
513-
if (chatWrapper != null)
514-
return new chatWrapper();
515-
516-
return new GeneralChatWrapper();
517-
}
518-
519-
void (wrapper satisfies never);
520-
521-
throw new Error("Unknown wrapper: " + wrapper);
522-
}

src/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ import {FalconChatWrapper} from "./chatWrappers/FalconChatWrapper.js";
3939
import {AlpacaChatWrapper} from "./chatWrappers/AlpacaChatWrapper.js";
4040
import {FunctionaryChatWrapper} from "./chatWrappers/FunctionaryChatWrapper.js";
4141
import {resolveChatWrapperBasedOnModel} from "./chatWrappers/resolveChatWrapperBasedOnModel.js";
42+
import {
43+
resolveChatWrapperBasedOnWrapperTypeName, chatWrapperTypeNames, type ChatWrapperTypeName
44+
} from "./bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.js";
4245
import {
4346
LlamaText, SpecialToken, BuiltinSpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue,
4447
type LlamaTextSpecialTokenJSON
@@ -118,6 +121,9 @@ export {
118121
AlpacaChatWrapper,
119122
FunctionaryChatWrapper,
120123
resolveChatWrapperBasedOnModel,
124+
resolveChatWrapperBasedOnWrapperTypeName,
125+
chatWrapperTypeNames,
126+
type ChatWrapperTypeName,
121127
LlamaText,
122128
SpecialToken,
123129
BuiltinSpecialToken,

0 commit comments

Comments
 (0)