@@ -77,28 +77,42 @@ pub struct MtmdContextParams {
77
77
78
78
impl Default for MtmdContextParams {
79
79
fn default ( ) -> Self {
80
- Self {
81
- use_gpu : false ,
82
- print_timings : true ,
83
- n_threads : 4 ,
84
- media_marker : CString :: new ( mtmd_default_marker ( ) ) . unwrap_or_default ( ) ,
85
- }
80
+ unsafe { llama_cpp_sys_2:: mtmd_context_params_default ( ) } . into ( )
86
81
}
87
82
}
88
83
89
84
impl From < & MtmdContextParams > for llama_cpp_sys_2:: mtmd_context_params {
90
85
fn from ( params : & MtmdContextParams ) -> Self {
91
86
let mut context = unsafe { llama_cpp_sys_2:: mtmd_context_params_default ( ) } ;
92
-
93
- context. use_gpu = params. use_gpu ;
94
- context. print_timings = params. print_timings ;
95
- context. n_threads = params. n_threads ;
96
- context. media_marker = params. media_marker . as_ptr ( ) ;
87
+ let MtmdContextParams {
88
+ use_gpu,
89
+ print_timings,
90
+ n_threads,
91
+ media_marker,
92
+ } = params;
93
+
94
+ context. use_gpu = * use_gpu;
95
+ context. print_timings = * print_timings;
96
+ context. n_threads = * n_threads;
97
+ context. media_marker = media_marker. as_ptr ( ) ;
97
98
98
99
context
99
100
}
100
101
}
101
102
103
+ impl From < llama_cpp_sys_2:: mtmd_context_params > for MtmdContextParams {
104
+ fn from ( params : llama_cpp_sys_2:: mtmd_context_params ) -> Self {
105
+ Self {
106
+ use_gpu : params. use_gpu ,
107
+ print_timings : params. print_timings ,
108
+ n_threads : params. n_threads ,
109
+ media_marker : unsafe { CStr :: from_ptr ( params. media_marker ) }
110
+ . to_owned ( )
111
+ . into ( ) ,
112
+ }
113
+ }
114
+ }
115
+
102
116
/// Text input configuration
103
117
///
104
118
/// # Examples
@@ -165,40 +179,41 @@ impl MtmdContext {
165
179
)
166
180
} ;
167
181
168
- if context. is_null ( ) {
169
- return Err ( MtmdInitError :: NullResult ) ;
170
- }
171
-
172
182
let context = NonNull :: new ( context) . ok_or ( MtmdInitError :: NullResult ) ?;
173
183
Ok ( Self { context } )
174
184
}
175
185
176
186
/// Check whether non-causal attention mask is needed before `llama_decode`.
177
- #[ must_use] pub fn decode_use_non_causal ( & self ) -> bool {
187
+ #[ must_use]
188
+ pub fn decode_use_non_causal ( & self ) -> bool {
178
189
unsafe { llama_cpp_sys_2:: mtmd_decode_use_non_causal ( self . context . as_ptr ( ) ) }
179
190
}
180
191
181
192
/// Check whether the current model uses M-RoPE for `llama_decode`.
182
193
///
183
194
/// M-RoPE (Multimodal Rotary Position Embedding) affects how positions
184
195
/// are calculated for multimodal inputs.
185
- #[ must_use] pub fn decode_use_mrope ( & self ) -> bool {
196
+ #[ must_use]
197
+ pub fn decode_use_mrope ( & self ) -> bool {
186
198
unsafe { llama_cpp_sys_2:: mtmd_decode_use_mrope ( self . context . as_ptr ( ) ) }
187
199
}
188
200
189
201
/// Check whether the current model supports vision input.
190
- #[ must_use] pub fn support_vision ( & self ) -> bool {
202
+ #[ must_use]
203
+ pub fn support_vision ( & self ) -> bool {
191
204
unsafe { llama_cpp_sys_2:: mtmd_support_vision ( self . context . as_ptr ( ) ) }
192
205
}
193
206
194
207
/// Check whether the current model supports audio input.
195
- #[ must_use] pub fn support_audio ( & self ) -> bool {
208
+ #[ must_use]
209
+ pub fn support_audio ( & self ) -> bool {
196
210
unsafe { llama_cpp_sys_2:: mtmd_support_audio ( self . context . as_ptr ( ) ) }
197
211
}
198
212
199
213
/// Get audio bitrate in Hz (e.g., 16000 for Whisper).
200
214
/// Returns -1 if audio is not supported.
201
- #[ must_use] pub fn get_audio_bitrate ( & self ) -> i32 {
215
+ #[ must_use]
216
+ pub fn get_audio_bitrate ( & self ) -> i32 {
202
217
unsafe { llama_cpp_sys_2:: mtmd_get_audio_bitrate ( self . context . as_ptr ( ) ) }
203
218
}
204
219
@@ -243,7 +258,7 @@ impl MtmdContext {
243
258
bitmaps : & [ & MtmdBitmap ] ,
244
259
) -> Result < MtmdInputChunks , MtmdTokenizeError > {
245
260
let chunks = MtmdInputChunks :: new ( ) ;
246
- let text_cstring = CString :: new ( text. text ) . unwrap_or_default ( ) ;
261
+ let text_cstring = CString :: new ( text. text ) ? ;
247
262
let input_text = llama_cpp_sys_2:: mtmd_input_text {
248
263
text : text_cstring. as_ptr ( ) ,
249
264
add_special : text. add_special ,
@@ -304,9 +319,6 @@ impl MtmdContext {
304
319
}
305
320
}
306
321
307
- unsafe impl Send for MtmdContext { }
308
- unsafe impl Sync for MtmdContext { }
309
-
310
322
impl Drop for MtmdContext {
311
323
fn drop ( & mut self ) {
312
324
unsafe { llama_cpp_sys_2:: mtmd_free ( self . context . as_ptr ( ) ) }
@@ -471,43 +483,48 @@ impl MtmdBitmap {
471
483
}
472
484
473
485
/// Get bitmap width in pixels.
474
- #[ must_use] pub fn nx ( & self ) -> u32 {
486
+ #[ must_use]
487
+ pub fn nx ( & self ) -> u32 {
475
488
unsafe { llama_cpp_sys_2:: mtmd_bitmap_get_nx ( self . bitmap . as_ptr ( ) ) }
476
489
}
477
490
478
491
/// Get bitmap height in pixels.
479
- #[ must_use] pub fn ny ( & self ) -> u32 {
492
+ #[ must_use]
493
+ pub fn ny ( & self ) -> u32 {
480
494
unsafe { llama_cpp_sys_2:: mtmd_bitmap_get_ny ( self . bitmap . as_ptr ( ) ) }
481
495
}
482
496
483
497
/// Get bitmap data as a byte slice.
484
498
///
485
499
/// For images: RGB format with length `nx * ny * 3`
486
500
/// For audio: PCM F32 format with length `n_samples * 4`
487
- #[ must_use] pub fn data ( & self ) -> & [ u8 ] {
501
+ #[ must_use]
502
+ pub fn data ( & self ) -> & [ u8 ] {
488
503
let ptr = unsafe { llama_cpp_sys_2:: mtmd_bitmap_get_data ( self . bitmap . as_ptr ( ) ) } ;
489
504
let len = unsafe { llama_cpp_sys_2:: mtmd_bitmap_get_n_bytes ( self . bitmap . as_ptr ( ) ) } ;
490
505
unsafe { slice:: from_raw_parts ( ptr, len) }
491
506
}
492
507
493
508
/// Check if this bitmap contains audio data (vs image data).
494
- #[ must_use] pub fn is_audio ( & self ) -> bool {
509
+ #[ must_use]
510
+ pub fn is_audio ( & self ) -> bool {
495
511
unsafe { llama_cpp_sys_2:: mtmd_bitmap_is_audio ( self . bitmap . as_ptr ( ) ) }
496
512
}
497
513
498
514
/// Get the bitmap's optional ID string.
499
515
///
500
516
/// Bitmap ID is useful for KV cache tracking and can e.g. be calculated
501
517
/// based on a hash of the bitmap data.
502
- #[ must_use] pub fn id ( & self ) -> Option < String > {
518
+ #[ must_use]
519
+ pub fn id ( & self ) -> Option < String > {
503
520
let ptr = unsafe { llama_cpp_sys_2:: mtmd_bitmap_get_id ( self . bitmap . as_ptr ( ) ) } ;
504
521
if ptr. is_null ( ) {
505
522
None
506
523
} else {
507
- unsafe { CStr :: from_ptr ( ptr) }
524
+ let id = unsafe { CStr :: from_ptr ( ptr) }
508
525
. to_string_lossy ( )
509
- . into_owned ( )
510
- . into ( )
526
+ . into_owned ( ) ;
527
+ Some ( id )
511
528
}
512
529
}
513
530
@@ -580,24 +597,28 @@ impl MtmdInputChunks {
580
597
/// assert_eq!(chunks.len(), 0);
581
598
/// assert!(chunks.is_empty());
582
599
/// ```
583
- #[ must_use] pub fn new ( ) -> Self {
600
+ #[ must_use]
601
+ pub fn new ( ) -> Self {
584
602
let chunks = unsafe { llama_cpp_sys_2:: mtmd_input_chunks_init ( ) } ;
585
603
let chunks = NonNull :: new ( chunks) . unwrap ( ) ;
586
604
Self { chunks }
587
605
}
588
606
589
607
/// Get the number of chunks
590
- #[ must_use] pub fn len ( & self ) -> usize {
608
+ #[ must_use]
609
+ pub fn len ( & self ) -> usize {
591
610
unsafe { llama_cpp_sys_2:: mtmd_input_chunks_size ( self . chunks . as_ptr ( ) ) }
592
611
}
593
612
594
613
/// Check if chunks collection is empty
595
- #[ must_use] pub fn is_empty ( & self ) -> bool {
614
+ #[ must_use]
615
+ pub fn is_empty ( & self ) -> bool {
596
616
self . len ( ) == 0
597
617
}
598
618
599
619
/// Get a chunk by index
600
- #[ must_use] pub fn get ( & self , index : usize ) -> Option < MtmdInputChunk > {
620
+ #[ must_use]
621
+ pub fn get ( & self , index : usize ) -> Option < MtmdInputChunk > {
601
622
if index >= self . len ( ) {
602
623
return None ;
603
624
}
@@ -619,15 +640,17 @@ impl MtmdInputChunks {
619
640
/// Get total number of tokens across all chunks.
620
641
///
621
642
/// This is useful for keeping track of KV cache size.
622
- #[ must_use] pub fn total_tokens ( & self ) -> usize {
643
+ #[ must_use]
644
+ pub fn total_tokens ( & self ) -> usize {
623
645
unsafe { llama_cpp_sys_2:: mtmd_helper_get_n_tokens ( self . chunks . as_ptr ( ) ) }
624
646
}
625
647
626
648
/// Get total position count across all chunks.
627
649
///
628
650
/// This is useful to keep track of `n_past`. Normally `n_pos` equals `n_tokens`,
629
651
/// but for M-RoPE it is different.
630
- #[ must_use] pub fn total_positions ( & self ) -> i32 {
652
+ #[ must_use]
653
+ pub fn total_positions ( & self ) -> i32 {
631
654
unsafe { llama_cpp_sys_2:: mtmd_helper_get_n_pos ( self . chunks . as_ptr ( ) ) }
632
655
}
633
656
@@ -709,7 +732,8 @@ pub struct MtmdInputChunk {
709
732
710
733
impl MtmdInputChunk {
711
734
/// Get the type of this chunk
712
- #[ must_use] pub fn chunk_type ( & self ) -> MtmdInputChunkType {
735
+ #[ must_use]
736
+ pub fn chunk_type ( & self ) -> MtmdInputChunkType {
713
737
let chunk_type = unsafe { llama_cpp_sys_2:: mtmd_input_chunk_get_type ( self . chunk . as_ptr ( ) ) } ;
714
738
MtmdInputChunkType :: from ( chunk_type)
715
739
}
@@ -721,7 +745,8 @@ impl MtmdInputChunk {
721
745
/// # Returns
722
746
///
723
747
/// Returns `Some(&[LlamaToken])` for text chunks, `None` otherwise.
724
- #[ must_use] pub fn text_tokens ( & self ) -> Option < & [ LlamaToken ] > {
748
+ #[ must_use]
749
+ pub fn text_tokens ( & self ) -> Option < & [ LlamaToken ] > {
725
750
if self . chunk_type ( ) != MtmdInputChunkType :: Text {
726
751
return None ;
727
752
}
@@ -744,21 +769,24 @@ impl MtmdInputChunk {
744
769
}
745
770
746
771
/// Get the number of tokens in this chunk
747
- #[ must_use] pub fn n_tokens ( & self ) -> usize {
772
+ #[ must_use]
773
+ pub fn n_tokens ( & self ) -> usize {
748
774
unsafe { llama_cpp_sys_2:: mtmd_input_chunk_get_n_tokens ( self . chunk . as_ptr ( ) ) }
749
775
}
750
776
751
777
/// Get the number of positions in this chunk.
752
778
///
753
779
/// Returns the number of temporal positions (always 1 for M-RoPE, `n_tokens` otherwise).
754
- #[ must_use] pub fn n_positions ( & self ) -> i32 {
780
+ #[ must_use]
781
+ pub fn n_positions ( & self ) -> i32 {
755
782
unsafe { llama_cpp_sys_2:: mtmd_input_chunk_get_n_pos ( self . chunk . as_ptr ( ) ) }
756
783
}
757
784
758
785
/// Get chunk ID if available.
759
786
///
760
787
/// Returns `None` for text chunks, may return an ID for image/audio chunks.
761
- #[ must_use] pub fn id ( & self ) -> Option < String > {
788
+ #[ must_use]
789
+ pub fn id ( & self ) -> Option < String > {
762
790
let ptr = unsafe { llama_cpp_sys_2:: mtmd_input_chunk_get_id ( self . chunk . as_ptr ( ) ) } ;
763
791
if ptr. is_null ( ) {
764
792
None
@@ -819,7 +847,8 @@ impl Drop for MtmdInputChunk {
819
847
/// let text = format!("Describe this image: {}", marker);
820
848
/// assert!(text.contains(marker));
821
849
/// ```
822
- #[ must_use] pub fn mtmd_default_marker ( ) -> & ' static str {
850
+ #[ must_use]
851
+ pub fn mtmd_default_marker ( ) -> & ' static str {
823
852
unsafe {
824
853
let c_str = llama_cpp_sys_2:: mtmd_default_marker ( ) ;
825
854
CStr :: from_ptr ( c_str) . to_str ( ) . unwrap_or ( "<__media__>" )
@@ -877,6 +906,9 @@ pub enum MtmdTokenizeError {
877
906
/// Image preprocessing error occurred
878
907
#[ error( "Image preprocessing error" ) ]
879
908
ImagePreprocessingError ,
909
+ /// Text contains characters that cannot be converted to C string
910
+ #[ error( "Failed to create CString from text: {0}" ) ]
911
+ CStringError ( #[ from] std:: ffi:: NulError ) ,
880
912
/// Unknown error occurred during tokenization
881
913
#[ error( "Unknown error: {0}" ) ]
882
914
UnknownError ( i32 ) ,
0 commit comments