
Commit 96b6bcc

Merge pull request #790 from fellhorn/dennis/feat/multi-modal
Multimodality support
2 parents 8f5d526 + f149f11

File tree

12 files changed: +1322 -4 lines changed


Cargo.lock

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ members = [
     "examples/embeddings",
     "examples/simple",
     "examples/reranker",
+    "examples/mtmd",
 ]

 [workspace.dependencies]

examples/mtmd/Cargo.toml

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
[package]
name = "mtmd"
version = "0.1.86"
edition = "2021"

[dependencies]
llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86", features = ["mtmd"] }
clap = { workspace = true, features = ["derive"] }

[features]
cuda = ["llama-cpp-2/cuda"]
metal = ["llama-cpp-2/metal"]
native = ["llama-cpp-2/native"]
vulkan = ["llama-cpp-2/vulkan"]

[lints]
workspace = true

[[example]]
name = "mtmd"
path = "src/mtmd.rs"

examples/mtmd/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Rust mtmd-cli implementation

Partial port of the mtmd-cli.cpp example in the llama-cpp repository.

## Usage

### Command Line Interface

To run the mtmd example, you first need to download the model gguf file and the multimodal projection file, e.g. for Gemma3 you may use:

```sh
wget https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf \
  https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/mmproj-F16.gguf
```

To then run the example on CPU, provide an image file `my_image.jpg` and run:

```sh
cargo run --release --example mtmd -- \
  --model ./gemma-3-4b-it-Q4_K_M.gguf \
  --mmproj ./mmproj-F16.gguf \
  --image my_image.jpg \
  --prompt "What is in the picture?" \
  --no-gpu \
  --no-mmproj-offload \
  --marker "<start_of_image>"
```
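
To offload work to the GPU instead, build the example with one of the GPU features declared in its `Cargo.toml` (`cuda`, `metal`, or `vulkan`) and drop the `--no-gpu` and `--no-mmproj-offload` flags. A minimal sketch, assuming a CUDA-capable build (invoked from `examples/mtmd`, or with `-p mtmd` added when running from the workspace root):

```sh
cargo run --release --example mtmd --features cuda -- \
  --model ./gemma-3-4b-it-Q4_K_M.gguf \
  --mmproj ./mmproj-F16.gguf \
  --image my_image.jpg \
  --prompt "What is in the picture?" \
  --marker "<start_of_image>"
```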

examples/mtmd/src/mtmd.rs

Lines changed: 307 additions & 0 deletions
@@ -0,0 +1,307 @@
//! Based on the mtmd cli example from llama.cpp.

use std::ffi::CString;
use std::io::{self, Write};
use std::num::NonZeroU32;
use std::path::Path;

use clap::Parser;

use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::mtmd::{
    MtmdBitmap, MtmdBitmapError, MtmdContext, MtmdContextParams, MtmdInputText,
};

use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel, Special};
use llama_cpp_2::sampling::LlamaSampler;

/// Command line parameters for the MTMD CLI application
#[derive(clap::Parser, Debug)]
#[command(name = "mtmd-cli")]
#[command(about = "Experimental CLI for multimodal llama.cpp")]
pub struct MtmdCliParams {
    /// Path to the model file
    #[arg(short = 'm', long = "model", value_name = "PATH")]
    pub model_path: String,
    /// Path to the multimodal projection file
    #[arg(long = "mmproj", value_name = "PATH")]
    pub mmproj_path: String,
    /// Path to image file(s)
    #[arg(long = "image", value_name = "PATH")]
    pub images: Vec<String>,
    /// Path to audio file(s)
    #[arg(long = "audio", value_name = "PATH")]
    pub audio: Vec<String>,
    /// Text prompt to use as input to the model. May include media markers - else they will be added automatically.
    #[arg(short = 'p', long = "prompt", value_name = "TEXT")]
    pub prompt: String,
    /// Number of tokens to predict (-1 for unlimited)
    #[arg(
        short = 'n',
        long = "n-predict",
        value_name = "N",
        default_value = "-1"
    )]
    pub n_predict: i32,
    /// Number of threads
    #[arg(short = 't', long = "threads", value_name = "N", default_value = "4")]
    pub n_threads: i32,
    /// Maximum number of tokens in context
    #[arg(long = "n-tokens", value_name = "N", default_value = "4096")]
    pub n_tokens: NonZeroU32,
    /// Chat template to use, default template if not provided
    #[arg(long = "chat-template", value_name = "TEMPLATE")]
    pub chat_template: Option<String>,
    /// Disable GPU acceleration
    #[arg(long = "no-gpu")]
    pub no_gpu: bool,
    /// Disable GPU offload for multimodal projection
    #[arg(long = "no-mmproj-offload")]
    pub no_mmproj_offload: bool,
    /// Media marker. If not provided, the default marker will be used.
    #[arg(long = "marker", value_name = "TEXT")]
    pub media_marker: Option<String>,
}

/// State of the MTMD CLI application.
#[allow(missing_debug_implementations)]
pub struct MtmdCliContext {
    /// The MTMD context for multimodal processing.
    pub mtmd_ctx: MtmdContext,
    /// The batch used for processing tokens.
    pub batch: LlamaBatch,
    /// The list of loaded bitmaps (images/audio).
    pub bitmaps: Vec<MtmdBitmap>,
    /// The number of past tokens processed.
    pub n_past: i32,
    /// The chat template used for formatting messages.
    pub chat_template: LlamaChatTemplate,
    /// The current chat messages history.
    pub chat: Vec<LlamaChatMessage>,
}

impl MtmdCliContext {
    /// Creates a new MTMD CLI context
    ///
    /// # Errors
    pub fn new(
        params: &MtmdCliParams,
        model: &LlamaModel,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        // Initialize MTMD context
        let mtmd_params = MtmdContextParams {
            use_gpu: !params.no_gpu && !params.no_mmproj_offload,
            print_timings: true,
            n_threads: params.n_threads,
            media_marker: CString::new(
                params
                    .media_marker
                    .as_ref()
                    .unwrap_or(&llama_cpp_2::mtmd::mtmd_default_marker().to_string())
                    .clone(),
            )?,
        };

        let mtmd_ctx = MtmdContext::init_from_file(&params.mmproj_path, model, &mtmd_params)?;

        let chat_template = model
            .chat_template(params.chat_template.as_deref())
            .map_err(|e| format!("Failed to get chat template: {e}"))?;

        let batch = LlamaBatch::new(params.n_tokens.get() as usize, 1);

        Ok(Self {
            mtmd_ctx,
            batch,
            chat: Vec::new(),
            bitmaps: Vec::new(),
            n_past: 0,
            chat_template,
        })
    }

    /// Loads media (image or audio) from the specified file path
    /// # Errors
    pub fn load_media(&mut self, path: &str) -> Result<(), MtmdBitmapError> {
        let bitmap = MtmdBitmap::from_file(&self.mtmd_ctx, path)?;
        self.bitmaps.push(bitmap);
        Ok(())
    }

    /// Evaluates a chat message, tokenizing and processing it through the model
    /// # Errors
    pub fn eval_message(
        &mut self,
        model: &LlamaModel,
        context: &mut LlamaContext,
        msg: LlamaChatMessage,
        add_bos: bool,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.chat.push(msg);

        // Format the message using chat template (simplified)
        let formatted_prompt = model.apply_chat_template(&self.chat_template, &self.chat, true)?;

        let input_text = MtmdInputText {
            text: formatted_prompt,
            add_special: add_bos,
            parse_special: true,
        };

        let bitmap_refs: Vec<&MtmdBitmap> = self.bitmaps.iter().collect();

        if bitmap_refs.is_empty() {
            println!("No bitmaps provided, only tokenizing text");
        } else {
            println!("Tokenizing with {} bitmaps", bitmap_refs.len());
        }

        // Tokenize the input
        let chunks = self.mtmd_ctx.tokenize(input_text, &bitmap_refs)?;

        println!("Tokenization complete, {} chunks created", chunks.len());

        // Clear bitmaps after tokenization
        self.bitmaps.clear();

        self.n_past = chunks.eval_chunks(&self.mtmd_ctx, context, 0, 0, 1, true)?;
        Ok(())
    }

    /// Generates a response by sampling tokens from the model
    /// # Errors
    pub fn generate_response(
        &mut self,
        model: &LlamaModel,
        context: &mut LlamaContext,
        sampler: &mut LlamaSampler,
        n_predict: i32,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut generated_tokens = Vec::new();
        let max_predict = if n_predict < 0 { i32::MAX } else { n_predict };

        for _i in 0..max_predict {
            // Sample next token
            let token = sampler.sample(context, 0);
            generated_tokens.push(token);
            sampler.accept(token);

            // Check for end of generation
            if model.is_eog_token(token) {
                println!();
                break;
            }

            // Print token
            let piece = model.token_to_str(token, Special::Tokenize)?;
            print!("{piece}");
            io::stdout().flush()?;

            // Prepare next batch
            self.batch.clear();
            self.batch.add(token, self.n_past, &[0], true)?;
            self.n_past += 1;

            // Decode
            context.decode(&mut self.batch)?;
        }

        Ok(())
    }
}

fn run_single_turn(
    ctx: &mut MtmdCliContext,
    model: &LlamaModel,
    context: &mut LlamaContext,
    sampler: &mut LlamaSampler,
    params: &MtmdCliParams,
) -> Result<(), Box<dyn std::error::Error>> {
    // Add media marker if not present
    let mut prompt = params.prompt.clone();
    let default_marker = llama_cpp_2::mtmd::mtmd_default_marker().to_string();
    let media_marker = params.media_marker.as_ref().unwrap_or(&default_marker);
    if !prompt.contains(media_marker) {
        prompt.push_str(media_marker);
    }

    // Load media files
    for image_path in &params.images {
        println!("Loading image: {image_path}");
        ctx.load_media(image_path)?;
    }
    for audio_path in &params.audio {
        ctx.load_media(audio_path)?;
    }

    // Create user message
    let msg = LlamaChatMessage::new("user".to_string(), prompt)?;

    println!("Evaluating message: {msg:?}");

    // Evaluate the message (prefill)
    ctx.eval_message(model, context, msg, true)?;

    // Generate response (decode)
    ctx.generate_response(model, context, sampler, params.n_predict)?;

    Ok(())
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let params = MtmdCliParams::parse();

    // Validate required parameters
    if !Path::new(&params.model_path).exists() {
        eprintln!("Error: Model file not found: {}", params.model_path);
        return Err("Model file not found".into());
    }

    if !Path::new(&params.mmproj_path).exists() {
        eprintln!(
            "Error: Multimodal projection file not found: {}",
            params.mmproj_path
        );
        return Err("Multimodal projection file not found".into());
    }

    println!("Loading model: {}", params.model_path);

    // Initialize backend
    let backend = LlamaBackend::init()?;

    // Setup model parameters
    let mut model_params = LlamaModelParams::default();
    if !params.no_gpu {
        model_params = model_params.with_n_gpu_layers(1_000_000); // Use all layers on GPU
    }

    // Load model
    let model = LlamaModel::load_from_file(&backend, &params.model_path, &model_params)?;

    // Create context
    let context_params = LlamaContextParams::default()
        .with_n_threads(params.n_threads)
        .with_n_batch(1)
        .with_n_ctx(Some(params.n_tokens));
    let mut context = model.new_context(&backend, context_params)?;

    // Create sampler
    let mut sampler = LlamaSampler::chain_simple([LlamaSampler::greedy()]);

    println!("Model loaded successfully");
    println!("Loading mtmd projection: {}", params.mmproj_path);

    // Create the MTMD context
    let mut ctx = MtmdCliContext::new(&params, &model)?;

    run_single_turn(&mut ctx, &model, &mut context, &mut sampler, &params)?;

    println!("\n");

    Ok(())
}
