@@ -5,11 +5,13 @@ import chalk from "chalk";
5
5
import withOra from "../../utils/withOra.js" ;
6
6
import { defaultChatSystemPrompt } from "../../config.js" ;
7
7
import { LlamaChatPromptWrapper } from "../../chatWrappers/LlamaChatPromptWrapper.js" ;
8
+ import { GeneralChatPromptWrapper } from "../../chatWrappers/GeneralChatPromptWrapper.js" ;
8
9
9
10
type ChatCommand = {
10
11
model : string ,
11
12
systemInfo : boolean ,
12
- systemPrompt : string
13
+ systemPrompt : string ,
14
+ wrapper : string
13
15
} ;
14
16
15
17
export const ChatCommand : CommandModule < object , ChatCommand > = {
@@ -37,11 +39,18 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
37
39
"System prompt to use against the model. " +
38
40
"[default value: " + defaultChatSystemPrompt . split ( "\n" ) . join ( " " ) + "]" ,
39
41
group : "Optional:"
42
+ } )
43
+ . option ( "wrapper" , {
44
+ type : "string" ,
45
+ default : "general" ,
46
+ choices : [ "general" , "llama" ] ,
47
+ description : "Chat wrapper to use" ,
48
+ group : "Optional:"
40
49
} ) ;
41
50
} ,
42
- async handler ( { model, systemInfo, systemPrompt} ) {
51
+ async handler ( { model, systemInfo, systemPrompt, wrapper } ) {
43
52
try {
44
- await RunChat ( { model, systemInfo, systemPrompt} ) ;
53
+ await RunChat ( { model, systemInfo, systemPrompt, wrapper } ) ;
45
54
} catch ( err ) {
46
55
console . error ( err ) ;
47
56
process . exit ( 1 ) ;
@@ -50,7 +59,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
50
59
} ;
51
60
52
61
53
- async function RunChat ( { model : modelArg , systemInfo, systemPrompt} : ChatCommand ) {
62
+ async function RunChat ( { model : modelArg , systemInfo, systemPrompt, wrapper } : ChatCommand ) {
54
63
const { LlamaChatSession} = await import ( "../../LlamaChatSession.js" ) ;
55
64
const { LlamaModel} = await import ( "../../LlamaModel.js" ) ;
56
65
@@ -61,7 +70,7 @@ async function RunChat({model: modelArg, systemInfo, systemPrompt}: ChatCommand)
61
70
model,
62
71
printLLamaSystemInfo : systemInfo ,
63
72
systemPrompt,
64
- promptWrapper : new LlamaChatPromptWrapper ( )
73
+ promptWrapper : createChatWrapper ( wrapper )
65
74
} ) ;
66
75
67
76
await withOra ( {
@@ -99,3 +108,13 @@ async function RunChat({model: modelArg, systemInfo, systemPrompt}: ChatCommand)
99
108
console . log ( ) ;
100
109
}
101
110
}
111
+
112
+ function createChatWrapper ( wrapper : string ) {
113
+ switch ( wrapper ) {
114
+ case "general" :
115
+ return new GeneralChatPromptWrapper ( ) ;
116
+ case "llama" :
117
+ return new LlamaChatPromptWrapper ( ) ;
118
+ }
119
+ throw new Error ( "Unknown wrapper: " + wrapper ) ;
120
+ }
0 commit comments