@@ -32,8 +32,11 @@ struct Args {
     #[command(subcommand)]
     model: Model,
     /// The prompt
-    #[clap(default_value = "Hello my name is")]
-    prompt: String,
+    #[clap(short = 'p', long)]
+    prompt: Option<String>,
+    /// Read the prompt from a file
+    #[clap(short = 'f', long, help = "prompt file to start generation")]
+    file: Option<String>,

     /// set the length of the prompt + output in tokens
     #[arg(long, default_value_t = 32)]
     n_len: i32,
@@ -44,6 +47,25 @@ struct Args {
     #[cfg(feature = "cublas")]
     #[clap(long)]
     disable_gpu: bool,
+    #[arg(short = 's', long, help = "RNG seed (default: 1234)")]
+    seed: Option<u32>,
+    #[arg(
+        short = 't',
+        long,
+        help = "number of threads to use during generation (default: use all available threads)"
+    )]
+    threads: Option<u32>,
+    #[arg(
+        long,
+        help = "number of threads to use during batch and prompt processing (default: use all available threads)"
+    )]
+    threads_batch: Option<u32>,
+    #[arg(
+        short = 'c',
+        long,
+        help = "size of the prompt context (default: loaded from the model)"
+    )]
+    ctx_size: Option<NonZeroU32>,
 }

 /// Parse a single key-value pair
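Each new option is declared as `Option<T>`, with the default documented in the help string rather than applied by clap: the thread and context-size defaults depend on runtime state (available cores, model metadata), so they can only be resolved in `main`. For a fixed default like the seed, a parse-time alternative would be clap's `default_value_t`; a hedged sketch (assumed clap 4 derive API, not what this commit does):

    // Hypothetical alternative: let clap apply the fixed default, so the
    // field is a plain u32 and `--help` renders the default automatically.
    #[arg(short = 's', long, default_value_t = 1234)]
    seed: u32,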
@@ -100,9 +122,14 @@ fn main() -> Result<()> {
         n_len,
         model,
         prompt,
+        file,
         #[cfg(feature = "cublas")]
         disable_gpu,
         key_value_overrides,
+        seed,
+        threads,
+        threads_batch,
+        ctx_size,
     } = Args::parse();

     // init LLM
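Note the `#[cfg(feature = "cublas")]` attribute inside the destructuring pattern: Rust allows cfg-gating individual fields of a struct pattern, so the `disable_gpu` binding only exists when the feature is compiled in. A minimal self-contained illustration (hypothetical struct, not from this crate):

    struct Flags {
        verbose: bool,
        #[cfg(feature = "cublas")]
        disable_gpu: bool,
    }

    fn demo(flags: Flags) {
        // `disable_gpu` is only bound when the `cublas` feature is enabled.
        let Flags {
            verbose,
            #[cfg(feature = "cublas")]
            disable_gpu,
        } = flags;
        let _ = verbose;
    }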
@@ -120,6 +147,17 @@ fn main() -> Result<()> {
         LlamaModelParams::default()
     };

+    let prompt = if let Some(str) = prompt {
+        if file.is_some() {
+            bail!("either prompt or file must be specified, but not both")
+        }
+        str
+    } else if let Some(file) = file {
+        std::fs::read_to_string(&file).with_context(|| format!("unable to read {file}"))?
+    } else {
+        "Hello my name is".to_string()
+    };
+
     let mut model_params = pin!(model_params);

     for (k, v) in &key_value_overrides {
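The prompt/file exclusivity is enforced by hand with `bail!`, and the old `"Hello my name is"` default moves here from the clap attribute. An alternative would be to let clap reject the combination at parse time; a hedged sketch against the `-p`/`-f` declarations from the first hunk (assumed clap 4 derive API, not what this commit does):

    // Hypothetical alternative: clap reports a usage error when --prompt
    // and --file are passed together, before main() ever runs.
    #[clap(short = 'p', long, conflicts_with = "file")]
    prompt: Option<String>,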
@@ -135,9 +173,15 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;

     // initialize the context
-    let ctx_params = LlamaContextParams::default()
-        .with_n_ctx(NonZeroU32::new(2048))
-        .with_seed(1234);
+    let mut ctx_params = LlamaContextParams::default()
+        .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
+        .with_seed(seed.unwrap_or(1234));
+    if let Some(threads) = threads {
+        ctx_params = ctx_params.with_n_threads(threads);
+    }
+    if let Some(threads_batch) = threads_batch.or(threads) {
+        ctx_params = ctx_params.with_n_threads_batch(threads_batch);
+    }

     let mut ctx = model
         .new_context(&backend, ctx_params)
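Since `NonZeroU32::new` already returns an `Option`, the context-size default could drop the `unwrap`; an equivalent sketch:

    // Equivalent formulation: NonZeroU32::new(2048) is Some(2048) because
    // the literal is non-zero, so ctx_size.or(...) needs no unwrap().
    let mut ctx_params = LlamaContextParams::default()
        .with_n_ctx(ctx_size.or(NonZeroU32::new(2048)))
        .with_seed(seed.unwrap_or(1234));

Also worth noting: `threads_batch.or(threads)` makes the batch thread count fall back to the generation thread count, so passing `-t 8` alone configures both.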