|
2 | 2 |
|
3 | 3 | use std::ffi::CString;
|
4 | 4 | use std::io::{self, Write};
|
| 5 | +use std::num::NonZeroU32; |
5 | 6 | use std::path::Path;
|
6 | 7 |
|
7 | 8 | use clap::Parser;
|
@@ -50,8 +51,8 @@ pub struct MtmdCliParams {
|
50 | 51 | #[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
|
51 | 52 | pub n_threads: i32,
|
52 | 53 | /// Maximum number of tokens in context
|
53 |
| - #[arg(long = "n-tokens", value_name = "N", default_value = "2048")] |
54 |
| - pub n_tokens: usize, |
| 54 | + #[arg(long = "n-tokens", value_name = "N", default_value = "4096")] |
| 55 | + pub n_tokens: NonZeroU32, |
55 | 56 | /// Chat template to use, default template if not provided
|
56 | 57 | #[arg(long = "chat-template", value_name = "TEMPLATE")]
|
57 | 58 | pub chat_template: Option<String>,
|
@@ -111,7 +112,7 @@ impl MtmdCliContext {
|
111 | 112 | .chat_template(params.chat_template.as_deref())
|
112 | 113 | .map_err(|e| format!("Failed to get chat template: {e}"))?;
|
113 | 114 |
|
114 |
| - let batch = LlamaBatch::new(params.n_tokens, 1); |
| 115 | + let batch = LlamaBatch::new(params.n_tokens.get() as usize, 1); |
115 | 116 |
|
116 | 117 | Ok(Self {
|
117 | 118 | mtmd_ctx,
|
@@ -285,7 +286,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
285 | 286 | // Create context
|
286 | 287 | let context_params = LlamaContextParams::default()
|
287 | 288 | .with_n_threads(params.n_threads)
|
288 |
| - .with_n_batch(1); |
| 289 | + .with_n_batch(1) |
| 290 | + .with_n_ctx(Some(params.n_tokens)); |
289 | 291 | let mut context = model.new_context(&backend, context_params)?;
|
290 | 292 |
|
291 | 293 | // Create sampler
|
|
0 commit comments