
Commit 9454897

committed: WIP
Signed-off-by: Dennis Keck <[email protected]>
1 parent c07bda5 commit 9454897

15 files changed: +965 −65 lines changed

Cargo.lock

Lines changed: 8 additions & 0 deletions
Generated file; diff not rendered by default.

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ members = [
     "examples/embeddings",
     "examples/simple",
     "examples/reranker",
+    "examples/mtmd",
 ]

 [workspace.dependencies]

examples/mtmd/Cargo.toml

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
[package]
name = "mtmd"
version = "0.1.86"
edition = "2021"

[dependencies]
llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86" }
clap = { workspace = true, features = ["derive"] }
# hf-hub = { workspace = true }
# anyhow = { workspace = true }
# encoding_rs = { workspace = true }

[features]
cuda = ["llama-cpp-2/cuda"]
metal = ["llama-cpp-2/metal"]
native = ["llama-cpp-2/native"]
vulkan = ["llama-cpp-2/vulkan"]

[lints]
workspace = true

[[example]]
name = "mtmd"
path = "src/mtmd.rs"

examples/mtmd/src/mtmd.rs

Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
//! Based on the mtmd cli example from llama.cpp.

use std::ffi::CString;
use std::io::{self, Write};
use std::path::Path;

use clap::Parser;

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::mtmd::*;

use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel, Special};
use llama_cpp_2::sampling::LlamaSampler;

/// Command line parameters for the MTMD CLI application
#[derive(clap::Parser, Debug)]
#[command(name = "mtmd-cli")]
#[command(about = "Experimental CLI for multimodal llama.cpp")]
pub struct MtmdCliParams {
    /// Path to the model file
    #[arg(short = 'm', long = "model", value_name = "PATH")]
    pub model_path: String,
    /// Path to the multimodal projection file
    #[arg(long = "mmproj", value_name = "PATH")]
    pub mmproj_path: String,
    /// Path to image file(s)
    #[arg(long = "image", value_name = "PATH")]
    pub images: Vec<String>,
    /// Path to audio file(s)
    #[arg(long = "audio", value_name = "PATH")]
    pub audio: Vec<String>,
    /// Text prompt to use as input to the model. May include media markers - else they will be added automatically.
    #[arg(short = 'p', long = "prompt", value_name = "TEXT")]
    pub prompt: String,
    /// Number of tokens to predict (-1 for unlimited)
    #[arg(
        short = 'n',
        long = "n-predict",
        value_name = "N",
        default_value = "-1"
    )]
    pub n_predict: i32,
    /// Number of threads
    #[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
    pub n_threads: i32,
    /// Maximum number of tokens in context
    #[arg(long = "n-tokens", value_name = "N", default_value = "2048")]
    pub n_tokens: usize,
    /// Chat template to use, default template if not provided
    #[arg(long = "chat-template", value_name = "TEMPLATE")]
    pub chat_template: Option<String>,
    /// Disable GPU acceleration
    #[arg(long = "no-gpu")]
    pub no_gpu: bool,
    /// Disable GPU offload for multimodal projection
    #[arg(long = "no-mmproj-offload")]
    pub no_mmproj_offload: bool,
    /// Media marker. If not provided, the default marker will be used.
    #[arg(long = "marker", value_name = "TEXT")]
    pub media_marker: Option<String>,
}

/// State of the MTMD CLI application.
#[allow(missing_debug_implementations)]
pub struct MtmdCliContext {
    /// The MTMD context for multimodal processing.
    pub mtmd_ctx: MtmdContext,
    /// The batch used for processing tokens.
    pub batch: LlamaBatch,
    /// The list of loaded bitmaps (images/audio).
    pub bitmaps: Vec<MtmdBitmap>,
    /// The number of past tokens processed.
    pub n_past: i32,
    /// The chat template used for formatting messages.
    pub chat_template: LlamaChatTemplate,
    /// The current chat messages history.
    pub chat: Vec<LlamaChatMessage>,
}

impl MtmdCliContext {
    /// Creates a new MTMD CLI context
    pub fn new(
        params: &MtmdCliParams,
        model: &LlamaModel,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        // Initialize MTMD context
        let mtmd_params = MtmdContextParams {
            use_gpu: !params.no_gpu && !params.no_mmproj_offload,
            print_timings: true,
            n_threads: params.n_threads,
            media_marker: CString::new(
                params
                    .media_marker
                    .as_ref()
                    .unwrap_or(&llama_cpp_2::mtmd::mtmd_default_marker().to_string())
                    .clone(),
            )?,
        };

        let mtmd_ctx = MtmdContext::init_from_file(&params.mmproj_path, model, mtmd_params)?;

        let chat_template = model
            .chat_template(params.chat_template.as_deref())
            .map_err(|e| format!("Failed to get chat template: {}", e))?;

        let batch = LlamaBatch::new(params.n_tokens, 1);

        Ok(Self {
            mtmd_ctx,
            batch,
            chat: Vec::new(),
            bitmaps: Vec::new(),
            n_past: 0,
            chat_template,
        })
    }

    /// Loads media (image or audio) from the specified file path
    pub fn load_media(&mut self, path: &str) -> Result<(), MtmdBitmapError> {
        let bitmap = MtmdBitmap::from_file(&self.mtmd_ctx, path)?;
        self.bitmaps.push(bitmap);
        Ok(())
    }

    /// Evaluates a chat message, tokenizing and processing it through the model
    pub fn eval_message(
        &mut self,
        model: &LlamaModel,
        context: &mut LlamaContext,
        msg: LlamaChatMessage,
        add_bos: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.chat.push(msg);

        // Format the message using chat template (simplified)
        let formatted_prompt = model.apply_chat_template(&self.chat_template, &self.chat, true)?;

        let input_text = MtmdInputText {
            text: formatted_prompt,
            add_special: add_bos,
            parse_special: true,
        };

        let bitmap_refs: Vec<&MtmdBitmap> = self.bitmaps.iter().collect();

        if bitmap_refs.is_empty() {
            println!("No bitmaps provided, only tokenizing text");
        } else {
            println!("Tokenizing with {} bitmaps", bitmap_refs.len());
        }

        // Tokenize the input
        let chunks = self.mtmd_ctx.tokenize(input_text, &bitmap_refs)?;

        println!("Tokenization complete, {} chunks created", chunks.len());

        // Clear bitmaps after tokenization
        self.bitmaps.clear();

        self.n_past = chunks.eval_chunks(&self.mtmd_ctx, &context, 0, 0, 1, true)?;
        Ok(())
    }

    /// Generates a response by sampling tokens from the model
    pub fn generate_response(
        &mut self,
        model: &LlamaModel,
        context: &mut LlamaContext,
        sampler: &mut LlamaSampler,
        n_predict: i32,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut generated_tokens = Vec::new();
        let max_predict = if n_predict < 0 { i32::MAX } else { n_predict };

        for _i in 0..max_predict {
            // Sample next token
            let token = sampler.sample(context, 0);
            generated_tokens.push(token);
            sampler.accept(token);

            // Check for end of generation
            if model.is_eog_token(token) {
                println!();
                break;
            }

            // Print token
            let piece = model.token_to_str(token, Special::Tokenize)?;
            print!("{}", piece);
            io::stdout().flush()?;

            // Prepare next batch
            self.batch.clear();
            self.batch.add(token, self.n_past, &[0], true)?;
            self.n_past += 1;

            // Decode
            context.decode(&mut self.batch)?;
        }

        Ok(())
    }
}

fn run_single_turn(
    ctx: &mut MtmdCliContext,
    model: &LlamaModel,
    context: &mut LlamaContext,
    sampler: &mut LlamaSampler,
    params: &MtmdCliParams,
) -> Result<(), Box<dyn std::error::Error>> {
    // Add media marker if not present
    let mut prompt = params.prompt.clone();
    let default_marker = llama_cpp_2::mtmd::mtmd_default_marker().to_string();
    let media_marker = params.media_marker.as_ref().unwrap_or(&default_marker);
    if !prompt.contains(media_marker) {
        prompt.push_str(media_marker);
    }

    // Load media files
    for image_path in &params.images {
        println!("Loading image: {}", image_path);
        ctx.load_media(image_path)?;
    }
    for audio_path in &params.audio {
        ctx.load_media(audio_path)?;
    }

    // Create user message
    let msg = LlamaChatMessage::new("user".to_string(), prompt)?;

    println!("Evaluating message: {:?}", msg);

    // Evaluate the message (prefill)
    ctx.eval_message(model, context, msg, true)?;

    // Generate response (decode)
    ctx.generate_response(model, context, sampler, params.n_predict)?;

    Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let params = MtmdCliParams::parse();

    // Validate required parameters
    if !Path::new(&params.model_path).exists() {
        eprintln!("Error: Model file not found: {}", params.model_path);
        return Err("Model file not found".into());
    }

    if !Path::new(&params.mmproj_path).exists() {
        eprintln!(
            "Error: Multimodal projection file not found: {}",
            params.mmproj_path
        );
        return Err("Multimodal projection file not found".into());
    }

    println!("Loading model: {}", params.model_path);

    // Initialize backend
    let backend = LlamaBackend::init()?;

    // Setup model parameters
    let mut model_params = LlamaModelParams::default();
    if !params.no_gpu {
        model_params = model_params.with_n_gpu_layers(1000000); // Use all layers on GPU
    }

    // Load model
    let model = LlamaModel::load_from_file(&backend, &params.model_path, &model_params)?;

    // Create context
    let context_params = LlamaContextParams::default()
        .with_n_threads(params.n_threads)
        .with_n_batch(1);
    let mut context = model.new_context(&backend, context_params)?;

    // Create sampler
    let mut sampler = LlamaSampler::chain_simple([LlamaSampler::greedy()]);

    println!("Model loaded successfully");
    println!("Loading mtmd projection: {}", params.mmproj_path);

    // Create the MTMD context
    let mut ctx = MtmdCliContext::new(&params, &model)?;

    run_single_turn(&mut ctx, &model, &mut context, &mut sampler, &params)?;

    println!("\n");

    Ok(())
}
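Putting the Cargo.toml above together with the flags defined by MtmdCliParams, a hypothetical invocation of the example might look as follows; the model, projector, and image paths are placeholders, not files referenced by this commit, and the exact cargo command may differ depending on the workspace setup:

# Hypothetical run; substitute real GGUF model/projector paths and your own image.
cargo run -p mtmd --example mtmd --features cuda -- \
    -m models/model.gguf \
    --mmproj models/mmproj.gguf \
    --image images/example.jpg \
    -p "Describe the image." \
    -n 128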

examples/reranker/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@ native = ["llama-cpp-2/native"]
 vulkan = ["llama-cpp-2/vulkan"]

 [lints]
-workspace = true
+workspace = true

llama-cpp-2/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ pub mod llama_backend;
 pub mod llama_batch;
 mod log;
 pub mod model;
+pub mod mtmd;
 pub mod sampling;
 pub mod timing;
 pub mod token;

llama-cpp-2/src/model.rs

Lines changed: 5 additions & 4 deletions
@@ -13,9 +13,9 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
 use crate::{
-    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError,
-    LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError, NewLlamaChatMessageError,
-    StringToTokenError, TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
+    LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
+    TokenToStringError,
 };

 pub mod params;

@@ -488,7 +488,8 @@ impl LlamaModel {
     pub fn n_head_kv(&self) -> u32 {
         // It's never possible for this to panic because while the API interface is defined as an int32_t,
         // the field it's accessing is a uint32_t.
-        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) }).unwrap()
+        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
+            .unwrap()
     }

     /// Get metadata value as a string by key name
