
Commit 636da79

Merge pull request #127 from SilasMarvin/silas-apply-chat-template
Added Apply Chat Template
2 parents 89f73ec + 6f9fa32

File tree

2 files changed: +98 −6 lines changed

llama-cpp-2/src/lib.rs
llama-cpp-2/src/model.rs


llama-cpp-2/src/lib.rs

Lines changed: 22 additions & 0 deletions
@@ -207,6 +207,28 @@ pub enum StringToTokenError {
     CIntConversionError(#[from] std::num::TryFromIntError),
 }
 
+/// Failed to apply model chat template.
+#[derive(Debug, thiserror::Error)]
+pub enum NewLlamaChatMessageError {
+    /// the string contained a null byte and thus could not be converted to a c string.
+    #[error("{0}")]
+    NulError(#[from] NulError),
+}
+
+/// Failed to apply model chat template.
+#[derive(Debug, thiserror::Error)]
+pub enum ApplyChatTemplateError {
+    /// the buffer was too small.
+    #[error("The buffer was too small. Please contact a maintainer and we will update it.")]
+    BuffSizeError,
+    /// the string contained a null byte and thus could not be converted to a c string.
+    #[error("{0}")]
+    NulError(#[from] NulError),
+    /// the string could not be converted to utf8.
+    #[error("{0}")]
+    FromUtf8Error(#[from] FromUtf8Error),
+}
+
 /// Get the time in microseconds according to ggml
 ///
 /// ```
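
For readers skimming the diff, here is a small illustration (not part of the commit) of how the new `NewLlamaChatMessageError` surfaces. It assumes the `LlamaChatMessage` constructor added to `model.rs` further down and the crate paths as they appear in this PR:

use llama_cpp_2::model::LlamaChatMessage;
use llama_cpp_2::NewLlamaChatMessageError;

fn build_messages() {
    // An interior NUL byte cannot be represented in a CString, so `new` fails
    // with NewLlamaChatMessageError::NulError instead of panicking.
    let bad = LlamaChatMessage::new("user".to_string(), "hi\0there".to_string());
    assert!(matches!(bad, Err(NewLlamaChatMessageError::NulError(_))));

    // Well-formed role/content strings construct fine.
    let ok = LlamaChatMessage::new("user".to_string(), "Hello!".to_string());
    assert!(ok.is_ok());
}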

llama-cpp-2/src/model.rs

Lines changed: 76 additions & 6 deletions
@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
 use crate::{
-    ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
-    TokenToStringError,
+    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
+    NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
 };
 
 pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
     pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
 }
 
+/// A Safe wrapper around `llama_chat_message`
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub struct LlamaChatMessage {
+    role: CString,
+    content: CString,
+}
+
+impl LlamaChatMessage {
+    /// Create a new `LlamaChatMessage`
+    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
+        Ok(Self {
+            role: CString::new(role)?,
+            content: CString::new(content)?,
+        })
+    }
+}
+
 /// How to determine if we should prepend a bos token to tokens
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum AddBos {
@@ -312,17 +329,16 @@ impl LlamaModel {
     /// Get chat template from model.
     ///
     /// # Errors
-    ///
+    ///
     /// * If the model has no chat template
     /// * If the chat template is not a valid [`CString`].
     #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
     pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
-
         // longest known template is about 1200 bytes from llama.cpp
         let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
         let chat_ptr = chat_temp.into_raw();
         let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
-
+
         let chat_template: String = unsafe {
             let ret = llama_cpp_sys_2::llama_model_meta_val_str(
                 self.model.as_ptr(),
@@ -337,7 +353,7 @@ impl LlamaModel {
             debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
             template
         };
-
+
         Ok(chat_template)
     }
 
@@ -388,6 +404,60 @@ impl LlamaModel {
 
         Ok(LlamaContext::new(self, context, params.embeddings()))
     }
+
+    /// Apply the models chat template to some messages.
+    /// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
+    ///
+    /// `tmpl` of None means to use the default template provided by llama.cpp for the model
+    ///
+    /// # Errors
+    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
+    #[tracing::instrument(skip_all)]
+    pub fn apply_chat_template(
+        &self,
+        tmpl: Option<String>,
+        chat: Vec<LlamaChatMessage>,
+        add_ass: bool,
+    ) -> Result<String, ApplyChatTemplateError> {
+        // Buffer is twice the length of messages per their recommendation
+        let message_length = chat.iter().fold(0, |acc, c| {
+            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
+        });
+        let mut buff: Vec<i8> = vec![0_i8; message_length * 2];
+
+        // Build our llama_cpp_sys_2 chat messages
+        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
+            .iter()
+            .map(|c| llama_cpp_sys_2::llama_chat_message {
+                role: c.role.as_ptr(),
+                content: c.content.as_ptr(),
+            })
+            .collect();
+        // Set the tmpl pointer
+        let tmpl = tmpl.map(CString::new);
+        let tmpl_ptr = match tmpl {
+            Some(str) => str?.as_ptr(),
+            None => std::ptr::null(),
+        };
+        let formatted_chat = unsafe {
+            let res = llama_cpp_sys_2::llama_chat_apply_template(
+                self.model.as_ptr(),
+                tmpl_ptr,
+                chat.as_ptr(),
+                chat.len(),
+                add_ass,
+                buff.as_mut_ptr().cast::<std::os::raw::c_char>(),
+                buff.len() as i32,
+            );
+            // A buffer twice the size should be sufficient for all models, if this is not the case for a new model, we can increase it
+            // The error message informs the user to contact a maintainer
+            if res > buff.len() as i32 {
+                return Err(ApplyChatTemplateError::BuffSizeError);
+            }
+            String::from_utf8(buff.iter().filter(|c| **c > 0).map(|&c| c as u8).collect())
+        }?;
+        Ok(formatted_chat)
+    }
 }
 
 impl Drop for LlamaModel {
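
Putting the two files together, a minimal usage sketch (again, not part of the commit) of the new API. It assumes an already-loaded `LlamaModel` whose metadata includes a chat template, and propagates errors with `?`:

use llama_cpp_2::model::{LlamaChatMessage, LlamaModel};

fn format_prompt(model: &LlamaModel) -> Result<String, Box<dyn std::error::Error>> {
    // Build the conversation; each message is validated (no interior NUL bytes).
    let chat = vec![
        LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
        LlamaChatMessage::new("user".to_string(), "What is the capital of France?".to_string())?,
    ];
    // None -> use the template embedded in the model;
    // true -> append the assistant prefix so generation continues as the assistant.
    let prompt = model.apply_chat_template(None, chat, true)?;
    Ok(prompt)
}

Passing `Some(template)` instead of `None` overrides the model's built-in template with a custom one, per the doc comment on `apply_chat_template`.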
