Skip to content

Commit 1fb7ec3

Browse files
committed
Added apply_chat_template to model
1 parent 89f73ec commit 1fb7ec3

File tree

2 files changed

+96
-6
lines changed

2 files changed

+96
-6
lines changed

llama-cpp-2/src/lib.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,28 @@ pub enum StringToTokenError {
207207
CIntConversionError(#[from] std::num::TryFromIntError),
208208
}
209209

210+
/// Failed to create a new `LlamaChatMessage`.
211+
#[derive(Debug, thiserror::Error)]
212+
pub enum NewLlamaChatMessageError {
213+
/// The string contained a null byte and thus could not be converted to a C string.
214+
#[error("{0}")]
215+
NulError(#[from] NulError),
216+
}
217+
218+
/// Failed to apply model chat template.
219+
#[derive(Debug, thiserror::Error)]
220+
pub enum ApplyChatTemplateError {
221+
/// The output buffer handed to llama.cpp was too small to hold the formatted chat.
222+
#[error("The buffer was too small. Please contact a maintainer")]
223+
BuffSizeError,
224+
/// The template string contained a null byte and thus could not be converted to a C string.
225+
#[error("{0}")]
226+
NulError(#[from] NulError),
227+
/// The formatted chat produced by llama.cpp could not be converted to UTF-8.
228+
#[error("{0}")]
229+
FromUtf8Error(#[from] FromUtf8Error),
230+
}
231+
210232
/// Get the time in microseconds according to ggml
211233
///
212234
/// ```

llama-cpp-2/src/model.rs

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use crate::model::params::LlamaModelParams;
1111
use crate::token::LlamaToken;
1212
use crate::token_type::LlamaTokenType;
1313
use crate::{
14-
ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
15-
TokenToStringError,
14+
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError,
15+
NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
1616
};
1717

1818
pub mod params;
@@ -25,6 +25,23 @@ pub struct LlamaModel {
2525
pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
2626
}
2727

28+
/// A Safe wrapper around `llama_chat_message`
29+
#[derive(Debug)]
30+
pub struct LlamaChatMessage {
31+
role: CString,
32+
content: CString,
33+
}
34+
35+
impl LlamaChatMessage {
36+
/// Create a new `LlamaChatMessage`
37+
pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
38+
Ok(Self {
39+
role: CString::new(role)?,
40+
content: CString::new(content)?,
41+
})
42+
}
43+
}
44+
2845
/// How to determine if we should prepend a bos token to tokens
2946
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
3047
pub enum AddBos {
@@ -312,17 +329,16 @@ impl LlamaModel {
312329
/// Get chat template from model.
313330
///
314331
/// # Errors
315-
///
332+
///
316333
/// * If the model has no chat template
317334
/// * If the chat template is not a valid [`CString`].
318335
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
319336
pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
320-
321337
// longest known template is about 1200 bytes from llama.cpp
322338
let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
323339
let chat_ptr = chat_temp.into_raw();
324340
let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
325-
341+
326342
let chat_template: String = unsafe {
327343
let ret = llama_cpp_sys_2::llama_model_meta_val_str(
328344
self.model.as_ptr(),
@@ -337,7 +353,7 @@ impl LlamaModel {
337353
debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
338354
template
339355
};
340-
356+
341357
Ok(chat_template)
342358
}
343359

@@ -388,6 +404,58 @@ impl LlamaModel {
388404

389405
Ok(LlamaContext::new(self, context, params.embeddings()))
390406
}
407+
408+
/// Apply the models chat template to some messages.
409+
/// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
410+
///
411+
/// # Errors
412+
/// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
413+
#[tracing::instrument(skip_all)]
414+
pub fn apply_chat_template(
415+
&self,
416+
tmpl: Option<String>,
417+
chat: Vec<LlamaChatMessage>,
418+
add_ass: bool,
419+
) -> Result<String, ApplyChatTemplateError> {
420+
// Buffer is twice the length of messages per their recommendation
421+
let message_length = chat.iter().fold(0, |acc, c| {
422+
acc + c.role.to_bytes().len() + c.content.to_bytes().len()
423+
});
424+
let mut buff: Vec<i8> = vec![0_i8; message_length * 2];
425+
// Build our llama_cpp_sys_2 chat messages
426+
let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
427+
.iter()
428+
.map(|c| llama_cpp_sys_2::llama_chat_message {
429+
role: c.role.as_ptr(),
430+
content: c.content.as_ptr(),
431+
})
432+
.collect();
433+
// Set the tmpl pointer
434+
let tmpl = tmpl.map(|v| CString::new(v));
435+
eprintln!("TEMPLATE AGAIN: {:?}", tmpl);
436+
let tmpl_ptr = match tmpl {
437+
Some(str) => str?.as_ptr(),
438+
None => std::ptr::null(),
439+
};
440+
let formatted_chat = unsafe {
441+
let res = llama_cpp_sys_2::llama_chat_apply_template(
442+
self.model.as_ptr(),
443+
tmpl_ptr,
444+
chat.as_ptr(),
445+
chat.len(),
446+
add_ass,
447+
buff.as_mut_ptr(),
448+
buff.len() as i32,
449+
);
450+
// This should never happen
451+
if res > buff.len() as i32 {
452+
return Err(ApplyChatTemplateError::BuffSizeError);
453+
}
454+
println!("BUFF: {:?}", buff);
455+
String::from_utf8(buff.iter().filter(|c| **c > 0).map(|&c| c as u8).collect())
456+
};
457+
Ok(formatted_chat?)
458+
}
391459
}
392460

393461
impl Drop for LlamaModel {

0 commit comments

Comments
 (0)