@@ -26,7 +26,7 @@ pub struct LlamaModel {
26
26
}
27
27
28
28
/// A Safe wrapper around `llama_chat_message`
29
- #[ derive( Debug ) ]
29
+ #[ derive( Debug , Eq , PartialEq , Clone ) ]
30
30
pub struct LlamaChatMessage {
31
31
role : CString ,
32
32
content : CString ,
@@ -408,6 +408,8 @@ impl LlamaModel {
408
408
/// Apply the models chat template to some messages.
409
409
/// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
410
410
///
411
+ /// `tmpl` of None means to use the default template provided by llama.cpp for the model
412
+ ///
411
413
/// # Errors
412
414
/// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
413
415
#[ tracing:: instrument( skip_all) ]
@@ -431,7 +433,7 @@ impl LlamaModel {
431
433
} )
432
434
. collect ( ) ;
433
435
// Set the tmpl pointer
434
- let tmpl = tmpl. map ( |v| CString :: new ( v ) ) ;
436
+ let tmpl = tmpl. map ( CString :: new) ;
435
437
let tmpl_ptr = match tmpl {
436
438
Some ( str) => str?. as_ptr ( ) ,
437
439
None => std:: ptr:: null ( ) ,
@@ -446,13 +448,14 @@ impl LlamaModel {
446
448
buff. as_mut_ptr ( ) ,
447
449
buff. len ( ) as i32 ,
448
450
) ;
449
- // This should never happen
451
+ // A buffer twice the size should be sufficient for all models, if this is not the case for a new model, we can increase it
452
+ // The error message informs the user to contact a maintainer
450
453
if res > buff. len ( ) as i32 {
451
454
return Err ( ApplyChatTemplateError :: BuffSizeError ) ;
452
455
}
453
456
String :: from_utf8 ( buff. iter ( ) . filter ( |c| * * c > 0 ) . map ( |& c| c as u8 ) . collect ( ) )
454
- } ;
455
- Ok ( formatted_chat? )
457
+ } ? ;
458
+ Ok ( formatted_chat)
456
459
}
457
460
}
458
461
0 commit comments