Skip to content

Commit 9c3dbdb

Browse files
authored
Merge pull request #212 from jasonmccampbell/jason/adnl-cli-args
Added a couple of additional CLI arguments for benchmarking purposes.
2 parents 5bec82d + c484065 commit 9c3dbdb

File tree

1 file changed

+49
-5
lines changed

1 file changed

+49
-5
lines changed

simple/src/main.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ struct Args {
3232
#[command(subcommand)]
3333
model: Model,
3434
/// The prompt
35-
#[clap(default_value = "Hello my name is")]
36-
prompt: String,
35+
#[clap(short = 'p', long)]
36+
prompt: Option<String>,
37+
/// Read the prompt from a file
38+
#[clap(short = 'f', long, help = "prompt file to start generation")]
39+
file: Option<String>,
3740
/// set the length of the prompt + output in tokens
3841
#[arg(long, default_value_t = 32)]
3942
n_len: i32,
@@ -44,6 +47,25 @@ struct Args {
4447
#[cfg(feature = "cublas")]
4548
#[clap(long)]
4649
disable_gpu: bool,
50+
#[arg(short = 's', long, help = "RNG seed (default: 1234)")]
51+
seed: Option<u32>,
52+
#[arg(
53+
short = 't',
54+
long,
55+
help = "number of threads to use during generation (default: use all available threads)"
56+
)]
57+
threads: Option<u32>,
58+
#[arg(
59+
long,
60+
help = "number of threads to use during batch and prompt processing (default: use all available threads)"
61+
)]
62+
threads_batch: Option<u32>,
63+
#[arg(
64+
short = 'c',
65+
long,
66+
help = "size of the prompt context (default: loaded from the model)"
67+
)]
68+
ctx_size: Option<NonZeroU32>,
4769
}
4870

4971
/// Parse a single key-value pair
@@ -100,9 +122,14 @@ fn main() -> Result<()> {
100122
n_len,
101123
model,
102124
prompt,
125+
file,
103126
#[cfg(feature = "cublas")]
104127
disable_gpu,
105128
key_value_overrides,
129+
seed,
130+
threads,
131+
threads_batch,
132+
ctx_size,
106133
} = Args::parse();
107134

108135
// init LLM
@@ -120,6 +147,17 @@ fn main() -> Result<()> {
120147
LlamaModelParams::default()
121148
};
122149

150+
let prompt = if let Some(str) = prompt {
151+
if file.is_some() {
152+
bail!("either prompt or file must be specified, but not both")
153+
}
154+
str
155+
} else if let Some(file) = file {
156+
std::fs::read_to_string(&file).with_context(|| format!("unable to read {file}"))?
157+
} else {
158+
"Hello my name is".to_string()
159+
};
160+
123161
let mut model_params = pin!(model_params);
124162

125163
for (k, v) in &key_value_overrides {
@@ -135,9 +173,15 @@ fn main() -> Result<()> {
135173
.with_context(|| "unable to load model")?;
136174

137175
// initialize the context
138-
let ctx_params = LlamaContextParams::default()
139-
.with_n_ctx(NonZeroU32::new(2048))
140-
.with_seed(1234);
176+
let mut ctx_params = LlamaContextParams::default()
177+
.with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap())))
178+
.with_seed(seed.unwrap_or(1234));
179+
if let Some(threads) = threads {
180+
ctx_params = ctx_params.with_n_threads(threads);
181+
}
182+
if let Some(threads_batch) = threads_batch.or(threads) {
183+
ctx_params = ctx_params.with_n_threads_batch(threads_batch);
184+
}
141185

142186
let mut ctx = model
143187
.new_context(&backend, ctx_params)

0 commit comments

Comments
 (0)