Skip to content

Commit 62f1511

Browse files
committed
Fix context length in mtmd example
Signed-off-by: Dennis Keck <[email protected]>
1 parent d025465 commit 62f1511

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

examples/mtmd/src/mtmd.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use std::ffi::CString;
44
use std::io::{self, Write};
5+
use std::num::NonZeroU32;
56
use std::path::Path;
67

78
use clap::Parser;
@@ -50,8 +51,8 @@ pub struct MtmdCliParams {
5051
#[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
5152
pub n_threads: i32,
5253
/// Maximum number of tokens in context
53-
#[arg(long = "n-tokens", value_name = "N", default_value = "2048")]
54-
pub n_tokens: usize,
54+
#[arg(long = "n-tokens", value_name = "N", default_value = "4096")]
55+
pub n_tokens: NonZeroU32,
5556
/// Chat template to use, default template if not provided
5657
#[arg(long = "chat-template", value_name = "TEMPLATE")]
5758
pub chat_template: Option<String>,
@@ -111,7 +112,7 @@ impl MtmdCliContext {
111112
.chat_template(params.chat_template.as_deref())
112113
.map_err(|e| format!("Failed to get chat template: {e}"))?;
113114

114-
let batch = LlamaBatch::new(params.n_tokens, 1);
115+
let batch = LlamaBatch::new(params.n_tokens.get() as usize, 1);
115116

116117
Ok(Self {
117118
mtmd_ctx,
@@ -285,7 +286,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
285286
// Create context
286287
let context_params = LlamaContextParams::default()
287288
.with_n_threads(params.n_threads)
288-
.with_n_batch(1);
289+
.with_n_batch(1)
290+
.with_n_ctx(Some(params.n_tokens));
289291
let mut context = model.new_context(&backend, context_params)?;
290292

291293
// Create sampler

0 commit comments

Comments
 (0)