@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
1111use crate :: token:: LlamaToken ;
1212use crate :: token_type:: LlamaTokenType ;
1313use crate :: {
14- ChatTemplateError , LlamaContextLoadError , LlamaModelLoadError , StringToTokenError ,
15- TokenToStringError ,
14+ ApplyChatTemplateError , ChatTemplateError , LlamaContextLoadError , LlamaModelLoadError ,
15+ NewLlamaChatMessageError , StringToTokenError , TokenToStringError ,
1616} ;
1717
1818pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
2525 pub ( crate ) model : NonNull < llama_cpp_sys_2:: llama_model > ,
2626}
2727
/// A safe wrapper around `llama_chat_message`.
///
/// Both fields are owned, NUL-terminated [`CString`]s, so a message can hand
/// out C string pointers for as long as the message itself is alive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlamaChatMessage {
    /// The role of the message author, stored as a C string.
    role: CString,
    /// The text of the message, stored as a C string.
    content: CString,
}
34+
35+ impl LlamaChatMessage {
36+ /// Create a new `LlamaChatMessage`
37+ pub fn new ( role : String , content : String ) -> Result < Self , NewLlamaChatMessageError > {
38+ Ok ( Self {
39+ role : CString :: new ( role) ?,
40+ content : CString :: new ( content) ?,
41+ } )
42+ }
43+ }
44+
2845/// How to determine if we should prepend a bos token to tokens
2946#[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
3047pub enum AddBos {
@@ -312,17 +329,16 @@ impl LlamaModel {
312329 /// Get chat template from model.
313330 ///
314331 /// # Errors
315- ///
332+ ///
316333 /// * If the model has no chat template
317334 /// * If the chat template is not a valid [`CString`].
318335 #[ allow( clippy:: missing_panics_doc) ] // we statically know this will not panic as
319336 pub fn get_chat_template ( & self , buf_size : usize ) -> Result < String , ChatTemplateError > {
320-
321337 // longest known template is about 1200 bytes from llama.cpp
322338 let chat_temp = CString :: new ( vec ! [ b'*' ; buf_size] ) . expect ( "no null" ) ;
323339 let chat_ptr = chat_temp. into_raw ( ) ;
324340 let chat_name = CString :: new ( "tokenizer.chat_template" ) . expect ( "no null bytes" ) ;
325-
341+
326342 let chat_template: String = unsafe {
327343 let ret = llama_cpp_sys_2:: llama_model_meta_val_str (
328344 self . model . as_ptr ( ) ,
@@ -337,7 +353,7 @@ impl LlamaModel {
337353 debug_assert_eq ! ( usize :: try_from( ret) . unwrap( ) , template. len( ) , "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case" , template. len( ) ) ;
338354 template
339355 } ;
340-
356+
341357 Ok ( chat_template)
342358 }
343359
@@ -388,6 +404,58 @@ impl LlamaModel {
388404
389405 Ok ( LlamaContext :: new ( self , context, params. embeddings ( ) ) )
390406 }
407+
408+ /// Apply the models chat template to some messages.
409+ /// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
410+ ///
411+ /// # Errors
412+ /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
413+ #[ tracing:: instrument( skip_all) ]
414+ pub fn apply_chat_template (
415+ & self ,
416+ tmpl : Option < String > ,
417+ chat : Vec < LlamaChatMessage > ,
418+ add_ass : bool ,
419+ ) -> Result < String , ApplyChatTemplateError > {
420+ // Buffer is twice the length of messages per their recommendation
421+ let message_length = chat. iter ( ) . fold ( 0 , |acc, c| {
422+ acc + c. role . to_bytes ( ) . len ( ) + c. content . to_bytes ( ) . len ( )
423+ } ) ;
424+ let mut buff: Vec < i8 > = vec ! [ 0_i8 ; message_length * 2 ] ;
425+ // Build our llama_cpp_sys_2 chat messages
426+ let chat: Vec < llama_cpp_sys_2:: llama_chat_message > = chat
427+ . iter ( )
428+ . map ( |c| llama_cpp_sys_2:: llama_chat_message {
429+ role : c. role . as_ptr ( ) ,
430+ content : c. content . as_ptr ( ) ,
431+ } )
432+ . collect ( ) ;
433+ // Set the tmpl pointer
434+ let tmpl = tmpl. map ( |v| CString :: new ( v) ) ;
435+ eprintln ! ( "TEMPLATE AGAIN: {:?}" , tmpl) ;
436+ let tmpl_ptr = match tmpl {
437+ Some ( str) => str?. as_ptr ( ) ,
438+ None => std:: ptr:: null ( ) ,
439+ } ;
440+ let formatted_chat = unsafe {
441+ let res = llama_cpp_sys_2:: llama_chat_apply_template (
442+ self . model . as_ptr ( ) ,
443+ tmpl_ptr,
444+ chat. as_ptr ( ) ,
445+ chat. len ( ) ,
446+ add_ass,
447+ buff. as_mut_ptr ( ) ,
448+ buff. len ( ) as i32 ,
449+ ) ;
450+ // This should never happen
451+ if res > buff. len ( ) as i32 {
452+ return Err ( ApplyChatTemplateError :: BuffSizeError ) ;
453+ }
454+ println ! ( "BUFF: {:?}" , buff) ;
455+ String :: from_utf8 ( buff. iter ( ) . filter ( |c| * * c > 0 ) . map ( |& c| c as u8 ) . collect ( ) )
456+ } ;
457+ Ok ( formatted_chat?)
458+ }
391459}
392460
393461impl Drop for LlamaModel {
0 commit comments