@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::LlamaTokenType;
use crate::{
-    ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
-    TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
+    NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
};

pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}

+/// A safe wrapper around `llama_chat_message`
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub struct LlamaChatMessage {
+    role: CString,
+    content: CString,
+}
+
+impl LlamaChatMessage {
+    /// Create a new `LlamaChatMessage`
+    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
+        Ok(Self {
+            role: CString::new(role)?,
+            content: CString::new(content)?,
+        })
+    }
+}
+
/// How to determine if we should prepend a bos token to tokens
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
@@ -312,17 +329,16 @@ impl LlamaModel {
    /// Get chat template from model.
    ///
    /// # Errors
-    ///
+    ///
    /// * If the model has no chat template
    /// * If the chat template is not a valid [`CString`].
    #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
-
        // longest known template is about 1200 bytes from llama.cpp
        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
        let chat_ptr = chat_temp.into_raw();
        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
-
+
        let chat_template: String = unsafe {
            let ret = llama_cpp_sys_2::llama_model_meta_val_str(
                self.model.as_ptr(),
@@ -337,7 +353,7 @@ impl LlamaModel {
            debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
            template
        };
-
+
        Ok(chat_template)
    }

@@ -388,6 +404,60 @@ impl LlamaModel {

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
+
+    /// Apply the model's chat template to some messages.
+    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
+    ///
+    /// `tmpl` of None means to use the default template provided by llama.cpp for the model
+    ///
+    /// # Errors
+    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
+    #[tracing::instrument(skip_all)]
+    pub fn apply_chat_template(
+        &self,
+        tmpl: Option<String>,
+        chat: Vec<LlamaChatMessage>,
+        add_ass: bool,
+    ) -> Result<String, ApplyChatTemplateError> {
+        // Buffer is twice the length of messages per their recommendation
+        let message_length = chat.iter().fold(0, |acc, c| {
+            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
+        });
+        let mut buff: Vec<i8> = vec![0_i8; message_length * 2];
+
+        // Build our llama_cpp_sys_2 chat messages
+        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
+            .iter()
+            .map(|c| llama_cpp_sys_2::llama_chat_message {
+                role: c.role.as_ptr(),
+                content: c.content.as_ptr(),
+            })
+            .collect();
+        // Set the tmpl pointer, keeping the CString bound so the pointer stays valid for the FFI call
+        let tmpl = tmpl.map(CString::new).transpose()?;
+        let tmpl_ptr = match &tmpl {
+            Some(str) => str.as_ptr(),
+            None => std::ptr::null(),
+        };
+        let formatted_chat = unsafe {
+            let res = llama_cpp_sys_2::llama_chat_apply_template(
+                self.model.as_ptr(),
+                tmpl_ptr,
+                chat.as_ptr(),
+                chat.len(),
+                add_ass,
+                buff.as_mut_ptr().cast::<std::os::raw::c_char>(),
+                buff.len() as i32,
+            );
+            // A buffer twice the size should be sufficient for all models; if this is not the case for a new model, we can increase it.
+            // The error message informs the user to contact a maintainer.
+            if res > buff.len() as i32 {
+                return Err(ApplyChatTemplateError::BuffSizeError);
+            }
+            String::from_utf8(buff.iter().filter(|c| **c > 0).map(|&c| c as u8).collect())
+        }?;
+        Ok(formatted_chat)
+    }
}

impl Drop for LlamaModel {
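For orientation, here is a minimal sketch of how the API added in this diff might be driven from downstream code. The crate/module paths, the `build_prompt` helper, the prompt contents, and the boxed error handling are illustrative assumptions, not part of this change; only `LlamaChatMessage::new` and `apply_chat_template` come from the diff above.

```rust
// Hypothetical downstream usage; `model` is assumed to be an already-loaded LlamaModel.
use llama_cpp_2::model::{LlamaChatMessage, LlamaModel};

fn build_prompt(model: &LlamaModel) -> Result<String, Box<dyn std::error::Error>> {
    // Roles and contents are converted to CStrings; an interior NUL byte
    // surfaces as NewLlamaChatMessageError through the `?`.
    let chat = vec![
        LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
        LlamaChatMessage::new("user".to_string(), "Write a haiku about llamas.".to_string())?,
    ];

    // `None` uses the template embedded in the model's metadata; `add_ass = true`
    // appends the assistant prefix so generation can continue from the prompt.
    let prompt = model.apply_chat_template(None, chat, true)?;
    Ok(prompt)
}
```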