@@ -32,8 +32,11 @@ struct Args {
     #[command(subcommand)]
     model: Model,
     /// The prompt
-    #[clap(default_value = "Hello my name is")]
-    prompt: String,
+    #[clap(short = 'p', long)]
+    prompt: Option<String>,
+    /// Read the prompt from a file
+    #[clap(short = 'f', long, help = "prompt file to start generation")]
+    file: Option<String>,
     /// set the length of the prompt + output in tokens
     #[arg(long, default_value_t = 32)]
     n_len: i32,
@@ -44,6 +47,25 @@ struct Args {
4447 #[ cfg( feature = "cublas" ) ]
4548 #[ clap( long) ]
4649 disable_gpu : bool ,
50+ #[ arg( short = 's' , long, help = "RNG seed (default: 1234)" ) ]
51+ seed : Option < u32 > ,
52+ #[ arg(
53+ short = 't' ,
54+ long,
55+ help = "number of threads to use during generation (default: use all available threads)"
56+ ) ]
57+ threads : Option < u32 > ,
58+ #[ arg(
59+ long,
60+ help = "number of threads to use during batch and prompt processing (default: use all available threads)"
61+ ) ]
62+ threads_batch : Option < u32 > ,
63+ #[ arg(
64+ short = 'c' ,
65+ long,
66+ help = "size of the prompt context (default: loaded from themodel)"
67+ ) ]
68+ ctx_size : Option < NonZeroU32 > ,
4769}
 
 /// Parse a single key-value pair
@@ -100,9 +122,14 @@ fn main() -> Result<()> {
         n_len,
         model,
         prompt,
+        file,
         #[cfg(feature = "cublas")]
         disable_gpu,
         key_value_overrides,
+        seed,
+        threads,
+        threads_batch,
+        ctx_size,
     } = Args::parse();
 
     // init LLM
@@ -120,6 +147,17 @@ fn main() -> Result<()> {
         LlamaModelParams::default()
     };
 
+    let prompt = if let Some(str) = prompt {
+        if file.is_some() {
+            bail!("either prompt or file must be specified, but not both")
+        }
+        str
+    } else if let Some(file) = file {
+        std::fs::read_to_string(&file).with_context(|| format!("unable to read {file}"))?
+    } else {
+        "Hello my name is".to_string()
+    };
+
     let mut model_params = pin!(model_params);
 
     for (k, v) in &key_value_overrides {
@@ -135,9 +173,15 @@ fn main() -> Result<()> {
         .with_context(|| "unable to load model")?;
 
     // initialize the context
-    let ctx_params = LlamaContextParams::default()
-        .with_n_ctx(NonZeroU32::new(2048))
-        .with_seed(1234);
+    let mut ctx_params = LlamaContextParams::default()
+        .with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
+        .with_seed(seed.unwrap_or(1234));
+    if let Some(threads) = threads {
+        ctx_params = ctx_params.with_n_threads(threads);
+    }
+    if let Some(threads_batch) = threads_batch.or(threads) {
+        ctx_params = ctx_params.with_n_threads_batch(threads_batch);
+    }
 
     let mut ctx = model
         .new_context(&backend, ctx_params)
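
Below is a minimal standalone sketch (not part of the commit) of the prompt-resolution rule the diff introduces: --prompt and --file are mutually exclusive, and the previous hard-coded default string is used when neither flag is given. The resolve_prompt helper name and the sample invocation are illustrative assumptions; the commit itself inlines this logic in main().

// Sketch of the mutual-exclusion / fallback behaviour added above.
// resolve_prompt is a hypothetical helper; the commit writes this inline in main().
use anyhow::{bail, Context, Result};

fn resolve_prompt(prompt: Option<String>, file: Option<String>) -> Result<String> {
    match (prompt, file) {
        // refusing both flags mirrors the bail! call in the diff
        (Some(_), Some(_)) => bail!("either prompt or file must be specified, but not both"),
        // an explicit --prompt wins
        (Some(p), None) => Ok(p),
        // --file: read the whole file into the prompt string
        (None, Some(f)) => {
            std::fs::read_to_string(&f).with_context(|| format!("unable to read {f}"))
        }
        // neither flag given: fall back to the previous default prompt
        (None, None) => Ok("Hello my name is".to_string()),
    }
}

fn main() -> Result<()> {
    // e.g. the equivalent of passing --prompt "Once upon a time"
    println!("{}", resolve_prompt(Some("Once upon a time".into()), None)?);
    Ok(())
}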