Skip to content

Commit 63791cb

Browse files
authored
Merge pull request #699 from nobodywho-ooo/chat-template-by-name
Reimplement get_chat_template to use `llama_model_chat_template`
2 parents d75f231 + 593257e commit 63791cb

File tree

3 files changed

+37
-84
lines changed

3 files changed

+37
-84
lines changed

llama-cpp-2/src/lib.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,19 @@ pub enum LLamaCppError {
6969
/// There was an error while getting the chat template from a model.
7070
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
7171
pub enum ChatTemplateError {
72-
/// gguf has no chat template
73-
#[error("the model has no meta val - returned code {0}")]
74-
MissingTemplate(i32),
72+
/// gguf has no chat template (by that name)
73+
#[error("chat template not found - returned null pointer")]
74+
MissingTemplate,
75+
76+
/// chat template contained a null byte
77+
#[error("null byte in string {0}")]
78+
NullError(#[from] NulError),
79+
7580
/// The chat template was not valid utf8.
7681
#[error(transparent)]
7782
Utf8Error(#[from] std::str::Utf8Error),
7883
}
7984

80-
enum InternalChatTemplateError {
81-
Permanent(ChatTemplateError),
82-
/// the buffer was too small.
83-
RetryWithLargerBuffer(usize),
84-
}
85-
8685
/// Failed to Load context
8786
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
8887
pub enum LlamaContextLoadError {

llama-cpp-2/src/model.rs

Lines changed: 28 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ use crate::model::params::LlamaModelParams;
1313
use crate::token::LlamaToken;
1414
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
1515
use crate::{
16-
ApplyChatTemplateError, ChatTemplateError, InternalChatTemplateError, LlamaContextLoadError,
17-
LlamaLoraAdapterInitError, LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError,
18-
TokenToStringError,
16+
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
17+
LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
1918
};
2019

2120
pub mod params;
@@ -36,7 +35,7 @@ pub struct LlamaLoraAdapter {
3635
pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
3736
}
3837

39-
/// A performance-friendly wrapper around [LlamaModel::get_chat_template] which is then
38+
/// A performance-friendly wrapper around [LlamaModel::chat_template] which is then
4039
/// fed into [LlamaModel::apply_chat_template] to convert a list of messages into an LLM
4140
/// prompt. Internally the template is stored as a CString to avoid round-trip conversions
4241
/// within the FFI.
@@ -506,83 +505,38 @@ impl LlamaModel {
506505
}
507506
}
508507

509-
fn get_chat_template_impl(
510-
&self,
511-
capacity: usize,
512-
) -> Result<LlamaChatTemplate, InternalChatTemplateError> {
513-
// longest known template is about 1200 bytes from llama.cpp
514-
// TODO: Once MaybeUninit support is better, this can be converted to use that instead of dummy initializing such a large array.
515-
let mut chat_temp = vec![b'*' as u8; capacity];
516-
let chat_name =
517-
CStr::from_bytes_with_nul(b"tokenizer.chat_template\0").expect("should have null byte");
518-
519-
let ret = unsafe {
520-
llama_cpp_sys_2::llama_model_meta_val_str(
521-
self.model.as_ptr(),
522-
chat_name.as_ptr(),
523-
chat_temp.as_mut_ptr() as *mut c_char,
524-
chat_temp.len(),
525-
)
526-
};
527-
528-
if ret < 0 {
529-
return Err(InternalChatTemplateError::Permanent(
530-
ChatTemplateError::MissingTemplate(ret),
531-
));
532-
}
533-
534-
let returned_len = ret as usize;
535-
536-
if ret as usize >= capacity {
537-
// >= is important because if the returned length is equal to capacity, it means we're missing a trailing null
538-
// since the returned length doesn't count the trailing null.
539-
return Err(InternalChatTemplateError::RetryWithLargerBuffer(
540-
returned_len,
541-
));
542-
}
543-
544-
assert_eq!(
545-
chat_temp.get(returned_len),
546-
Some(&0),
547-
"should end with null byte"
548-
);
549-
550-
chat_temp.resize(returned_len + 1, 0);
551-
552-
Ok(LlamaChatTemplate(unsafe {
553-
CString::from_vec_with_nul_unchecked(chat_temp)
554-
}))
555-
}
556-
557-
/// Get chat template from model. If this fails, you may either want to fail to chat or pick the
558-
/// specific shortcode that llama.cpp supports templates it has baked-in directly into its codebase
559-
/// as fallbacks when the model doesn't contain. NOTE: If you don't specify a chat template, then
560-
/// it uses chatml by default which is unlikely to actually be the correct template for your model
561-
/// and you'll get weird results back.
508+
/// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
562509
///
563510
/// You supply this into [Self::apply_chat_template] to get back a string with the appropriate template
564511
/// substitution applied to convert a list of messages into a prompt the LLM can use to complete
565512
/// the chat.
566513
///
514+
/// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
515+
/// to parse jinja templates not supported by the llama.cpp template engine.
516+
///
567517
/// # Errors
568518
///
569-
/// * If the model has no chat template
519+
/// * If the model has no chat template by that name
570520
/// * If the chat template is not a valid [`CString`].
571-
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
572-
pub fn get_chat_template(&self) -> Result<LlamaChatTemplate, ChatTemplateError> {
573-
// Typical chat templates are quite small. Let's start with a small allocation likely to succeed.
574-
// Ideally the performance of this would be negligible but uninitialized arrays in Rust are currently
575-
// still not well supported so we end up initializing the chat template buffer twice. One idea might
576-
// be to use a very small value here that will likely fail (like 0 or 1) and then use that to initialize.
577-
// Not sure which approach is the most optimal but in practice this should work well.
578-
match self.get_chat_template_impl(200) {
579-
Ok(t) => Ok(t),
580-
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
581-
Err(InternalChatTemplateError::RetryWithLargerBuffer(actual_len)) => match self.get_chat_template_impl(actual_len + 1) {
582-
Ok(t) => Ok(t),
583-
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
584-
Err(InternalChatTemplateError::RetryWithLargerBuffer(unexpected_len)) => panic!("Was told that the template length was {actual_len} but now it's {unexpected_len}"),
585-
}
521+
pub fn chat_template(
522+
&self,
523+
name: Option<&str>,
524+
) -> Result<LlamaChatTemplate, ChatTemplateError> {
525+
let name_cstr = name.map(CString::new);
526+
let name_ptr = match name_cstr {
527+
Some(Ok(name)) => name.as_ptr(),
528+
_ => std::ptr::null(),
529+
};
530+
let result =
531+
unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };
532+
533+
// Convert the returned C pointer into an owned CString if it is not null
534+
if result.is_null() {
535+
Err(ChatTemplateError::MissingTemplate)
536+
} else {
537+
let chat_template_cstr = unsafe { CStr::from_ptr(result) };
538+
let chat_template = CString::new(chat_template_cstr.to_bytes())?;
539+
Ok(LlamaChatTemplate(chat_template))
586540
}
587541
}
588542

@@ -672,7 +626,7 @@ impl LlamaModel {
672626
/// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
673627
/// string.
674628
///
675-
/// Use [Self::get_chat_template] to retrieve the template baked into the model (this is the preferred
629+
/// Use [Self::chat_template] to retrieve the template baked into the model (this is the preferred
676630
/// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
677631
///
678632
You probably want to set `add_ass` to true so that the generated template string ends with the

llama-cpp-sys-2/llama.cpp

0 commit comments

Comments
 (0)