Skip to content

Commit d025465

Browse files
committed
Review round 1
* Remove unsafe Send, Sync
* Cleanup error handling
* Use default mtmd_context directly

Signed-off-by: Dennis Keck <[email protected]>
1 parent e1f1e04 commit d025465

File tree

3 files changed: +81 additions, -47 deletions

examples/mtmd/src/mtmd.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@ use llama_cpp_2::context::params::LlamaContextParams;
1010
use llama_cpp_2::context::LlamaContext;
1111
use llama_cpp_2::llama_batch::LlamaBatch;
1212
use llama_cpp_2::model::params::LlamaModelParams;
13-
use llama_cpp_2::mtmd::{MtmdBitmap, MtmdBitmapError, MtmdContext, MtmdContextParams, MtmdInputText};
13+
use llama_cpp_2::mtmd::{
14+
MtmdBitmap, MtmdBitmapError, MtmdContext, MtmdContextParams, MtmdInputText,
15+
};
1416

1517
use llama_cpp_2::llama_backend::LlamaBackend;
1618
use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel, Special};

llama-cpp-2/src/mtmd.rs

Lines changed: 76 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -77,28 +77,42 @@ pub struct MtmdContextParams {
7777

7878
impl Default for MtmdContextParams {
7979
fn default() -> Self {
80-
Self {
81-
use_gpu: false,
82-
print_timings: true,
83-
n_threads: 4,
84-
media_marker: CString::new(mtmd_default_marker()).unwrap_or_default(),
85-
}
80+
unsafe { llama_cpp_sys_2::mtmd_context_params_default() }.into()
8681
}
8782
}
8883

8984
impl From<&MtmdContextParams> for llama_cpp_sys_2::mtmd_context_params {
9085
fn from(params: &MtmdContextParams) -> Self {
9186
let mut context = unsafe { llama_cpp_sys_2::mtmd_context_params_default() };
92-
93-
context.use_gpu = params.use_gpu;
94-
context.print_timings = params.print_timings;
95-
context.n_threads = params.n_threads;
96-
context.media_marker = params.media_marker.as_ptr();
87+
let MtmdContextParams {
88+
use_gpu,
89+
print_timings,
90+
n_threads,
91+
media_marker,
92+
} = params;
93+
94+
context.use_gpu = *use_gpu;
95+
context.print_timings = *print_timings;
96+
context.n_threads = *n_threads;
97+
context.media_marker = media_marker.as_ptr();
9798

9899
context
99100
}
100101
}
101102

103+
impl From<llama_cpp_sys_2::mtmd_context_params> for MtmdContextParams {
104+
fn from(params: llama_cpp_sys_2::mtmd_context_params) -> Self {
105+
Self {
106+
use_gpu: params.use_gpu,
107+
print_timings: params.print_timings,
108+
n_threads: params.n_threads,
109+
media_marker: unsafe { CStr::from_ptr(params.media_marker) }
110+
.to_owned()
111+
.into(),
112+
}
113+
}
114+
}
115+
102116
/// Text input configuration
103117
///
104118
/// # Examples
@@ -165,40 +179,41 @@ impl MtmdContext {
165179
)
166180
};
167181

168-
if context.is_null() {
169-
return Err(MtmdInitError::NullResult);
170-
}
171-
172182
let context = NonNull::new(context).ok_or(MtmdInitError::NullResult)?;
173183
Ok(Self { context })
174184
}
175185

176186
/// Check whether non-causal attention mask is needed before `llama_decode`.
177-
#[must_use] pub fn decode_use_non_causal(&self) -> bool {
187+
#[must_use]
188+
pub fn decode_use_non_causal(&self) -> bool {
178189
unsafe { llama_cpp_sys_2::mtmd_decode_use_non_causal(self.context.as_ptr()) }
179190
}
180191

181192
/// Check whether the current model uses M-RoPE for `llama_decode`.
182193
///
183194
/// M-RoPE (Multimodal Rotary Position Embedding) affects how positions
184195
/// are calculated for multimodal inputs.
185-
#[must_use] pub fn decode_use_mrope(&self) -> bool {
196+
#[must_use]
197+
pub fn decode_use_mrope(&self) -> bool {
186198
unsafe { llama_cpp_sys_2::mtmd_decode_use_mrope(self.context.as_ptr()) }
187199
}
188200

189201
/// Check whether the current model supports vision input.
190-
#[must_use] pub fn support_vision(&self) -> bool {
202+
#[must_use]
203+
pub fn support_vision(&self) -> bool {
191204
unsafe { llama_cpp_sys_2::mtmd_support_vision(self.context.as_ptr()) }
192205
}
193206

194207
/// Check whether the current model supports audio input.
195-
#[must_use] pub fn support_audio(&self) -> bool {
208+
#[must_use]
209+
pub fn support_audio(&self) -> bool {
196210
unsafe { llama_cpp_sys_2::mtmd_support_audio(self.context.as_ptr()) }
197211
}
198212

199213
/// Get audio bitrate in Hz (e.g., 16000 for Whisper).
200214
/// Returns -1 if audio is not supported.
201-
#[must_use] pub fn get_audio_bitrate(&self) -> i32 {
215+
#[must_use]
216+
pub fn get_audio_bitrate(&self) -> i32 {
202217
unsafe { llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) }
203218
}
204219

@@ -243,7 +258,7 @@ impl MtmdContext {
243258
bitmaps: &[&MtmdBitmap],
244259
) -> Result<MtmdInputChunks, MtmdTokenizeError> {
245260
let chunks = MtmdInputChunks::new();
246-
let text_cstring = CString::new(text.text).unwrap_or_default();
261+
let text_cstring = CString::new(text.text)?;
247262
let input_text = llama_cpp_sys_2::mtmd_input_text {
248263
text: text_cstring.as_ptr(),
249264
add_special: text.add_special,
@@ -304,9 +319,6 @@ impl MtmdContext {
304319
}
305320
}
306321

307-
unsafe impl Send for MtmdContext {}
308-
unsafe impl Sync for MtmdContext {}
309-
310322
impl Drop for MtmdContext {
311323
fn drop(&mut self) {
312324
unsafe { llama_cpp_sys_2::mtmd_free(self.context.as_ptr()) }
@@ -471,43 +483,48 @@ impl MtmdBitmap {
471483
}
472484

473485
/// Get bitmap width in pixels.
474-
#[must_use] pub fn nx(&self) -> u32 {
486+
#[must_use]
487+
pub fn nx(&self) -> u32 {
475488
unsafe { llama_cpp_sys_2::mtmd_bitmap_get_nx(self.bitmap.as_ptr()) }
476489
}
477490

478491
/// Get bitmap height in pixels.
479-
#[must_use] pub fn ny(&self) -> u32 {
492+
#[must_use]
493+
pub fn ny(&self) -> u32 {
480494
unsafe { llama_cpp_sys_2::mtmd_bitmap_get_ny(self.bitmap.as_ptr()) }
481495
}
482496

483497
/// Get bitmap data as a byte slice.
484498
///
485499
/// For images: RGB format with length `nx * ny * 3`
486500
/// For audio: PCM F32 format with length `n_samples * 4`
487-
#[must_use] pub fn data(&self) -> &[u8] {
501+
#[must_use]
502+
pub fn data(&self) -> &[u8] {
488503
let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_data(self.bitmap.as_ptr()) };
489504
let len = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_n_bytes(self.bitmap.as_ptr()) };
490505
unsafe { slice::from_raw_parts(ptr, len) }
491506
}
492507

493508
/// Check if this bitmap contains audio data (vs image data).
494-
#[must_use] pub fn is_audio(&self) -> bool {
509+
#[must_use]
510+
pub fn is_audio(&self) -> bool {
495511
unsafe { llama_cpp_sys_2::mtmd_bitmap_is_audio(self.bitmap.as_ptr()) }
496512
}
497513

498514
/// Get the bitmap's optional ID string.
499515
///
500516
/// Bitmap ID is useful for KV cache tracking and can e.g. be calculated
501517
/// based on a hash of the bitmap data.
502-
#[must_use] pub fn id(&self) -> Option<String> {
518+
#[must_use]
519+
pub fn id(&self) -> Option<String> {
503520
let ptr = unsafe { llama_cpp_sys_2::mtmd_bitmap_get_id(self.bitmap.as_ptr()) };
504521
if ptr.is_null() {
505522
None
506523
} else {
507-
unsafe { CStr::from_ptr(ptr) }
524+
let id = unsafe { CStr::from_ptr(ptr) }
508525
.to_string_lossy()
509-
.into_owned()
510-
.into()
526+
.into_owned();
527+
Some(id)
511528
}
512529
}
513530

@@ -580,24 +597,28 @@ impl MtmdInputChunks {
580597
/// assert_eq!(chunks.len(), 0);
581598
/// assert!(chunks.is_empty());
582599
/// ```
583-
#[must_use] pub fn new() -> Self {
600+
#[must_use]
601+
pub fn new() -> Self {
584602
let chunks = unsafe { llama_cpp_sys_2::mtmd_input_chunks_init() };
585603
let chunks = NonNull::new(chunks).unwrap();
586604
Self { chunks }
587605
}
588606

589607
/// Get the number of chunks
590-
#[must_use] pub fn len(&self) -> usize {
608+
#[must_use]
609+
pub fn len(&self) -> usize {
591610
unsafe { llama_cpp_sys_2::mtmd_input_chunks_size(self.chunks.as_ptr()) }
592611
}
593612

594613
/// Check if chunks collection is empty
595-
#[must_use] pub fn is_empty(&self) -> bool {
614+
#[must_use]
615+
pub fn is_empty(&self) -> bool {
596616
self.len() == 0
597617
}
598618

599619
/// Get a chunk by index
600-
#[must_use] pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
620+
#[must_use]
621+
pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
601622
if index >= self.len() {
602623
return None;
603624
}
@@ -619,15 +640,17 @@ impl MtmdInputChunks {
619640
/// Get total number of tokens across all chunks.
620641
///
621642
/// This is useful for keeping track of KV cache size.
622-
#[must_use] pub fn total_tokens(&self) -> usize {
643+
#[must_use]
644+
pub fn total_tokens(&self) -> usize {
623645
unsafe { llama_cpp_sys_2::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) }
624646
}
625647

626648
/// Get total position count across all chunks.
627649
///
628650
/// This is useful to keep track of `n_past`. Normally `n_pos` equals `n_tokens`,
629651
/// but for M-RoPE it is different.
630-
#[must_use] pub fn total_positions(&self) -> i32 {
652+
#[must_use]
653+
pub fn total_positions(&self) -> i32 {
631654
unsafe { llama_cpp_sys_2::mtmd_helper_get_n_pos(self.chunks.as_ptr()) }
632655
}
633656

@@ -709,7 +732,8 @@ pub struct MtmdInputChunk {
709732

710733
impl MtmdInputChunk {
711734
/// Get the type of this chunk
712-
#[must_use] pub fn chunk_type(&self) -> MtmdInputChunkType {
735+
#[must_use]
736+
pub fn chunk_type(&self) -> MtmdInputChunkType {
713737
let chunk_type = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_type(self.chunk.as_ptr()) };
714738
MtmdInputChunkType::from(chunk_type)
715739
}
@@ -721,7 +745,8 @@ impl MtmdInputChunk {
721745
/// # Returns
722746
///
723747
/// Returns `Some(&[LlamaToken])` for text chunks, `None` otherwise.
724-
#[must_use] pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
748+
#[must_use]
749+
pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
725750
if self.chunk_type() != MtmdInputChunkType::Text {
726751
return None;
727752
}
@@ -744,21 +769,24 @@ impl MtmdInputChunk {
744769
}
745770

746771
/// Get the number of tokens in this chunk
747-
#[must_use] pub fn n_tokens(&self) -> usize {
772+
#[must_use]
773+
pub fn n_tokens(&self) -> usize {
748774
unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_tokens(self.chunk.as_ptr()) }
749775
}
750776

751777
/// Get the number of positions in this chunk.
752778
///
753779
/// Returns the number of temporal positions (always 1 for M-RoPE, `n_tokens` otherwise).
754-
#[must_use] pub fn n_positions(&self) -> i32 {
780+
#[must_use]
781+
pub fn n_positions(&self) -> i32 {
755782
unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_n_pos(self.chunk.as_ptr()) }
756783
}
757784

758785
/// Get chunk ID if available.
759786
///
760787
/// Returns `None` for text chunks, may return an ID for image/audio chunks.
761-
#[must_use] pub fn id(&self) -> Option<String> {
788+
#[must_use]
789+
pub fn id(&self) -> Option<String> {
762790
let ptr = unsafe { llama_cpp_sys_2::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
763791
if ptr.is_null() {
764792
None
@@ -819,7 +847,8 @@ impl Drop for MtmdInputChunk {
819847
/// let text = format!("Describe this image: {}", marker);
820848
/// assert!(text.contains(marker));
821849
/// ```
822-
#[must_use] pub fn mtmd_default_marker() -> &'static str {
850+
#[must_use]
851+
pub fn mtmd_default_marker() -> &'static str {
823852
unsafe {
824853
let c_str = llama_cpp_sys_2::mtmd_default_marker();
825854
CStr::from_ptr(c_str).to_str().unwrap_or("<__media__>")
@@ -877,6 +906,9 @@ pub enum MtmdTokenizeError {
877906
/// Image preprocessing error occurred
878907
#[error("Image preprocessing error")]
879908
ImagePreprocessingError,
909+
/// Text contains characters that cannot be converted to C string
910+
#[error("Failed to create CString from text: {0}")]
911+
CStringError(#[from] std::ffi::NulError),
880912
/// Unknown error occurred during tokenization
881913
#[error("Unknown error: {0}")]
882914
UnknownError(i32),

llama-cpp-sys-2/build.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -277,8 +277,8 @@ fn main() {
277277
// Configure mtmd feature if enabled
278278
if cfg!(feature = "mtmd") {
279279
bindings_builder = bindings_builder
280-
.allowlist_function("mtmd_.*")
281-
.allowlist_type("mtmd_.*");
280+
.allowlist_function("mtmd_.*")
281+
.allowlist_type("mtmd_.*");
282282
}
283283

284284
// Configure Android-specific bindgen settings

0 commit comments

Comments
 (0)