@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
11
11
use crate :: token:: LlamaToken ;
12
12
use crate :: token_type:: LlamaTokenType ;
13
13
use crate :: {
14
- ChatTemplateError , LlamaContextLoadError , LlamaModelLoadError , StringToTokenError ,
15
- TokenToStringError ,
14
+ ApplyChatTemplateError , ChatTemplateError , LlamaContextLoadError , LlamaModelLoadError ,
15
+ NewLlamaChatMessageError , StringToTokenError , TokenToStringError ,
16
16
} ;
17
17
18
18
pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
25
25
pub ( crate ) model : NonNull < llama_cpp_sys_2:: llama_model > ,
26
26
}
27
27
28
+ /// A Safe wrapper around `llama_chat_message`
29
+ #[ derive( Debug ) ]
30
+ pub struct LlamaChatMessage {
31
+ role : CString ,
32
+ content : CString ,
33
+ }
34
+
35
+ impl LlamaChatMessage {
36
+ /// Create a new `LlamaChatMessage`
37
+ pub fn new ( role : String , content : String ) -> Result < Self , NewLlamaChatMessageError > {
38
+ Ok ( Self {
39
+ role : CString :: new ( role) ?,
40
+ content : CString :: new ( content) ?,
41
+ } )
42
+ }
43
+ }
44
+
28
45
/// How to determine if we should prepend a bos token to tokens
29
46
#[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
30
47
pub enum AddBos {
@@ -312,17 +329,16 @@ impl LlamaModel {
312
329
/// Get chat template from model.
313
330
///
314
331
/// # Errors
315
- ///
332
+ ///
316
333
/// * If the model has no chat template
317
334
/// * If the chat template is not a valid [`CString`].
318
335
#[ allow( clippy:: missing_panics_doc) ] // we statically know this will not panic as
319
336
pub fn get_chat_template ( & self , buf_size : usize ) -> Result < String , ChatTemplateError > {
320
-
321
337
// longest known template is about 1200 bytes from llama.cpp
322
338
let chat_temp = CString :: new ( vec ! [ b'*' ; buf_size] ) . expect ( "no null" ) ;
323
339
let chat_ptr = chat_temp. into_raw ( ) ;
324
340
let chat_name = CString :: new ( "tokenizer.chat_template" ) . expect ( "no null bytes" ) ;
325
-
341
+
326
342
let chat_template: String = unsafe {
327
343
let ret = llama_cpp_sys_2:: llama_model_meta_val_str (
328
344
self . model . as_ptr ( ) ,
@@ -337,7 +353,7 @@ impl LlamaModel {
337
353
debug_assert_eq ! ( usize :: try_from( ret) . unwrap( ) , template. len( ) , "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case" , template. len( ) ) ;
338
354
template
339
355
} ;
340
-
356
+
341
357
Ok ( chat_template)
342
358
}
343
359
@@ -388,6 +404,58 @@ impl LlamaModel {
388
404
389
405
Ok ( LlamaContext :: new ( self , context, params. embeddings ( ) ) )
390
406
}
407
+
408
+ /// Apply the models chat template to some messages.
409
+ /// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
410
+ ///
411
+ /// # Errors
412
+ /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
413
+ #[ tracing:: instrument( skip_all) ]
414
+ pub fn apply_chat_template (
415
+ & self ,
416
+ tmpl : Option < String > ,
417
+ chat : Vec < LlamaChatMessage > ,
418
+ add_ass : bool ,
419
+ ) -> Result < String , ApplyChatTemplateError > {
420
+ // Buffer is twice the length of messages per their recommendation
421
+ let message_length = chat. iter ( ) . fold ( 0 , |acc, c| {
422
+ acc + c. role . to_bytes ( ) . len ( ) + c. content . to_bytes ( ) . len ( )
423
+ } ) ;
424
+ let mut buff: Vec < i8 > = vec ! [ 0_i8 ; message_length * 2 ] ;
425
+ // Build our llama_cpp_sys_2 chat messages
426
+ let chat: Vec < llama_cpp_sys_2:: llama_chat_message > = chat
427
+ . iter ( )
428
+ . map ( |c| llama_cpp_sys_2:: llama_chat_message {
429
+ role : c. role . as_ptr ( ) ,
430
+ content : c. content . as_ptr ( ) ,
431
+ } )
432
+ . collect ( ) ;
433
+ // Set the tmpl pointer
434
+ let tmpl = tmpl. map ( |v| CString :: new ( v) ) ;
435
+ eprintln ! ( "TEMPLATE AGAIN: {:?}" , tmpl) ;
436
+ let tmpl_ptr = match tmpl {
437
+ Some ( str) => str?. as_ptr ( ) ,
438
+ None => std:: ptr:: null ( ) ,
439
+ } ;
440
+ let formatted_chat = unsafe {
441
+ let res = llama_cpp_sys_2:: llama_chat_apply_template (
442
+ self . model . as_ptr ( ) ,
443
+ tmpl_ptr,
444
+ chat. as_ptr ( ) ,
445
+ chat. len ( ) ,
446
+ add_ass,
447
+ buff. as_mut_ptr ( ) ,
448
+ buff. len ( ) as i32 ,
449
+ ) ;
450
+ // This should never happen
451
+ if res > buff. len ( ) as i32 {
452
+ return Err ( ApplyChatTemplateError :: BuffSizeError ) ;
453
+ }
454
+ println ! ( "BUFF: {:?}" , buff) ;
455
+ String :: from_utf8 ( buff. iter ( ) . filter ( |c| * * c > 0 ) . map ( |& c| c as u8 ) . collect ( ) )
456
+ } ;
457
+ Ok ( formatted_chat?)
458
+ }
391
459
}
392
460
393
461
impl Drop for LlamaModel {
0 commit comments