@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
 use crate::{
-    ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
-    TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
+    NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
 };
 
 pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
     pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
 }
 
+/// A safe wrapper around `llama_chat_message`
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub struct LlamaChatMessage {
+    role: CString,
+    content: CString,
+}
+
+impl LlamaChatMessage {
+    /// Create a new `LlamaChatMessage`
+    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
+        Ok(Self {
+            role: CString::new(role)?,
+            content: CString::new(content)?,
+        })
+    }
+}
+
 /// How to determine if we should prepend a bos token to tokens
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum AddBos {
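
A minimal usage sketch of the new type (not part of the diff): construction only fails when the role or content contains an interior null byte, since both fields are stored as `CString`s.

    let msg = LlamaChatMessage::new("user".to_string(), "Hello!".to_string())
        .expect("role and content contain no interior null bytes");
    // The only failure mode is an interior null byte in either string:
    assert!(LlamaChatMessage::new("user".to_string(), "bad\0byte".to_string()).is_err());
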
@@ -312,17 +329,16 @@ impl LlamaModel {
     /// Get chat template from model.
     ///
     /// # Errors
-    /// 
+    ///
     /// * If the model has no chat template
     /// * If the chat template is not a valid [`CString`].
     #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
     pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
-
         // longest known template is about 1200 bytes from llama.cpp
         let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
         let chat_ptr = chat_temp.into_raw();
         let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
-        
+
         let chat_template: String = unsafe {
             let ret = llama_cpp_sys_2::llama_model_meta_val_str(
                 self.model.as_ptr(),
@@ -337,7 +353,7 @@ impl LlamaModel {
             debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
             template
         };
-        
+
         Ok(chat_template)
     }
 
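
A brief usage sketch (not part of the diff), assuming `model` is an already-loaded `LlamaModel`; the caller picks `buf_size`, and per the comment above ~1200 bytes is the longest template known to llama.cpp.

    // 2048 bytes comfortably covers the longest known (~1200 byte) template
    let template = model
        .get_chat_template(2048)
        .expect("model embeds a chat template");
    println!("{template}");
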
@@ -388,6 +404,60 @@ impl LlamaModel {
 
         Ok(LlamaContext::new(self, context, params.embeddings()))
     }
+
+    /// Apply the model's chat template to some messages.
+    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
+    ///
+    /// A `tmpl` of `None` means to use the default template provided by llama.cpp for the model.
+    ///
+    /// # Errors
+    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
+    #[tracing::instrument(skip_all)]
+    pub fn apply_chat_template(
+        &self,
+        tmpl: Option<String>,
+        chat: Vec<LlamaChatMessage>,
+        add_ass: bool,
+    ) -> Result<String, ApplyChatTemplateError> {
+        // Buffer is twice the length of the messages, per the llama.cpp recommendation
+        let message_length = chat.iter().fold(0, |acc, c| {
+            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
+        });
+        let mut buff: Vec<i8> = vec![0_i8; message_length * 2];
+
+        // Build our llama_cpp_sys_2 chat messages
+        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
+            .iter()
+            .map(|c| llama_cpp_sys_2::llama_chat_message {
+                role: c.role.as_ptr(),
+                content: c.content.as_ptr(),
+            })
+            .collect();
+        // Convert the template, keeping the `CString` bound so the pointer stays valid
+        let tmpl = tmpl.map(CString::new).transpose()?;
+        let tmpl_ptr = match &tmpl {
+            Some(str) => str.as_ptr(),
+            None => std::ptr::null(),
+        };
+        let formatted_chat = unsafe {
+            let res = llama_cpp_sys_2::llama_chat_apply_template(
+                self.model.as_ptr(),
+                tmpl_ptr,
+                chat.as_ptr(),
+                chat.len(),
+                add_ass,
+                buff.as_mut_ptr().cast::<std::os::raw::c_char>(),
+                buff.len() as i32,
+            );
+            // A buffer twice the size should suffice; if not, the error tells the user to contact a maintainer
+            if res > buff.len() as i32 {
+                return Err(ApplyChatTemplateError::BuffSizeError);
+            }
+            // Take exactly `res` bytes so multi-byte UTF-8 (negative as `i8`) survives
+            String::from_utf8(buff.iter().take(res as usize).map(|&c| c as u8).collect())
+        }?;
+        Ok(formatted_chat)
+    }
 }
 
 impl Drop for LlamaModel {
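
A usage sketch for the new method (not part of the diff), assuming `model` is a loaded `LlamaModel`: `None` for `tmpl` selects the template stored in the model's metadata, and `add_ass = true` appends the assistant prefix so generation continues as the assistant.

    let chat = vec![
        LlamaChatMessage::new("system".to_string(), "You are a poet.".to_string())
            .expect("no null bytes"),
        LlamaChatMessage::new("user".to_string(), "Write a haiku about Rust.".to_string())
            .expect("no null bytes"),
    ];
    let prompt = model
        .apply_chat_template(None, chat, true)
        .expect("template applied");
    // `prompt` is now ready to tokenize and feed to a context
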