8 changes: 2 additions & 6 deletions examples/simple/src/main.rs
@@ -10,14 +10,14 @@ use anyhow::{anyhow, bail, Context, Result};
use clap::Parser;
use hf_hub::api::sync::ApiBuilder;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions};
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::model::{AddBos, Special};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions};

use std::ffi::CString;
use std::io::Write;
@@ -67,11 +67,7 @@ struct Args {
help = "size of the prompt context (default: loaded from themodel)"
)]
ctx_size: Option<NonZeroU32>,
#[arg(
short = 'v',
long,
help = "enable verbose llama.cpp logs",
)]
#[arg(short = 'v', long, help = "enable verbose llama.cpp logs")]
verbose: bool,
}

9 changes: 6 additions & 3 deletions llama-cpp-2/src/lib.rs
@@ -69,9 +69,6 @@ pub enum LLamaCppError {
/// There was an error while getting the chat template from a model.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum ChatTemplateError {
/// the buffer was too small.
#[error("The buffer was too small. However, a buffer size of {0} would be just large enough.")]
BuffSizeError(usize),
/// gguf has no chat template
#[error("the model has no meta val - returned code {0}")]
MissingTemplate(i32),
@@ -80,6 +77,12 @@ pub enum ChatTemplateError {
Utf8Error(#[from] std::str::Utf8Error),
}

enum InternalChatTemplateError {
Permanent(ChatTemplateError),
/// the buffer was too small.
RetryWithLargerBuffer(usize),
}

/// Failed to Load context
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LlamaContextLoadError {
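The split into `InternalChatTemplateError::Permanent` and `RetryWithLargerBuffer` is what lets the public `get_chat_template` hide buffer sizing from callers: the retry variant carries the length llama.cpp reported, so a second attempt can allocate exactly that much plus a trailing NUL. A minimal sketch of that control flow (the helper signature here is illustrative only, not part of the crate; the real implementation is in the model.rs changes below):

fn fetch_with_retry<T>(
    fetch: impl Fn(usize) -> Result<T, InternalChatTemplateError>,
) -> Result<T, ChatTemplateError> {
    // First attempt with a small guess; typical chat templates fit in a few hundred bytes.
    match fetch(200) {
        Ok(value) => Ok(value),
        Err(InternalChatTemplateError::Permanent(e)) => Err(e),
        // The retry variant carries the length llama.cpp reported, so the second
        // attempt can allocate exactly `len + 1` bytes (room for the trailing NUL).
        Err(InternalChatTemplateError::RetryWithLargerBuffer(len)) => match fetch(len + 1) {
            Ok(value) => Ok(value),
            Err(InternalChatTemplateError::Permanent(e)) => Err(e),
            Err(InternalChatTemplateError::RetryWithLargerBuffer(n)) => {
                unreachable!("llama.cpp reported a different required length ({n}) on the retry")
            }
        },
    }
}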
3 changes: 2 additions & 1 deletion llama-cpp-2/src/log.rs
@@ -171,7 +171,8 @@ impl State {
} else {
let level = self
.previous_level
.load(std::sync::atomic::Ordering::Acquire) as llama_cpp_sys_2::ggml_log_level;
.load(std::sync::atomic::Ordering::Acquire)
as llama_cpp_sys_2::ggml_log_level;
tracing::warn!(
inferred_level = level,
text = text,
150 changes: 118 additions & 32 deletions llama-cpp-2/src/model.rs
@@ -1,9 +1,10 @@
//! A safe wrapper around `llama_model`.
use std::ffi::{c_char, CString};
use std::ffi::{c_char, CStr, CString};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::NonNull;
use std::str::{FromStr, Utf8Error};

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
@@ -12,8 +13,9 @@ use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
ApplyChatTemplateError, ChatTemplateError, InternalChatTemplateError, LlamaContextLoadError,
LlamaLoraAdapterInitError, LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError,
TokenToStringError,
};

pub mod params;
@@ -34,6 +36,42 @@ pub struct LlamaLoraAdapter {
pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// A performance-friendly wrapper around the chat template returned by [`LlamaModel::get_chat_template`],
/// which is then fed into [`LlamaModel::apply_chat_template`] to convert a list of messages into an LLM
/// prompt. Internally the template is stored as a `CString` to avoid round-trip conversions
/// across the FFI boundary.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
/// Create a new template from a string. This can either be the name of a llama.cpp [chat template](https://github.com/ggerganov/llama.cpp/blob/8a8c4ceb6050bd9392609114ca56ae6d26f5b8f5/src/llama-chat.cpp#L27-L61)
/// like "chatml" or "llama3" or an actual Jinja template for llama.cpp to interpret.
pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
Ok(Self(CString::from_str(template)?))
}

/// Accesses the template as a c string reference.
pub fn as_c_str(&self) -> &CStr {
&self.0
}

/// Attempts to convert the CString into a Rust str reference.
pub fn to_str(&self) -> Result<&str, Utf8Error> {
self.0.to_str()
}

/// Convenience method to create an owned String.
pub fn to_string(&self) -> Result<String, Utf8Error> {
self.to_str().map(str::to_string)
}
}

impl std::fmt::Debug for LlamaChatTemplate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}

/// A Safe wrapper around `llama_chat_message`
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
@@ -408,41 +446,84 @@ impl LlamaModel {
unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
}

/// Get chat template from model.
///
/// # Errors
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
fn get_chat_template_impl(
&self,
capacity: usize,
) -> Result<LlamaChatTemplate, InternalChatTemplateError> {
// longest known template is about 1200 bytes from llama.cpp
let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
let chat_ptr = chat_temp.into_raw();
let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
// TODO: Once MaybeUninit support is better, this can be converted to use that instead of dummy initializing such a large array.
let mut chat_temp = vec![b'*' as u8; capacity];
let chat_name =
CStr::from_bytes_with_nul(b"tokenizer.chat_template\0").expect("should have null byte");

let ret = unsafe {
llama_cpp_sys_2::llama_model_meta_val_str(
self.model.as_ptr(),
chat_name.as_ptr(),
chat_ptr,
buf_size,
chat_temp.as_mut_ptr() as *mut c_char,
chat_temp.len(),
)
};

if ret < 0 {
return Err(ChatTemplateError::MissingTemplate(ret));
return Err(InternalChatTemplateError::Permanent(
ChatTemplateError::MissingTemplate(ret),
));
}

let template_c = unsafe { CString::from_raw(chat_ptr) };
let template = template_c.to_str()?;
let returned_len = ret as usize;

let ret: usize = ret.try_into().unwrap();
if template.len() < ret {
return Err(ChatTemplateError::BuffSizeError(ret + 1));
if ret as usize >= capacity {
// >= is important because if the returned length is equal to capacity, it means we're missing a trailing null
// since the returned length doesn't count the trailing null.
return Err(InternalChatTemplateError::RetryWithLargerBuffer(
returned_len,
));
}

Ok(template.to_owned())
assert_eq!(
chat_temp.get(returned_len),
Some(&0),
"should end with null byte"
);

chat_temp.resize(returned_len + 1, 0);

Ok(LlamaChatTemplate(unsafe {
CString::from_vec_with_nul_unchecked(chat_temp)
}))
}

/// Get the chat template from the model. If this fails, you may either want to abort the chat or fall
/// back to one of the shortcodes for the templates llama.cpp has baked directly into its codebase for
/// models that don't ship one. NOTE: if you don't specify a chat template, llama.cpp uses chatml by
/// default, which is unlikely to be the correct template for your model, and you'll get weird
/// results back.
///
/// You supply this into [Self::apply_chat_template] to get back a string with the appropriate template
/// substitution applied to convert a list of messages into a prompt the LLM can use to complete
/// the chat.
///
/// # Errors
///
/// * If the model has no chat template
/// * If the chat template is not a valid [`CString`].
#[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
pub fn get_chat_template(&self) -> Result<LlamaChatTemplate, ChatTemplateError> {
// Typical chat templates are quite small. Let's start with a small allocation likely to succeed.
// Ideally the performance of this would be negligible but uninitialized arrays in Rust are currently
// still not well supported so we end up initializing the chat template buffer twice. One idea might
// be to use a very small value here that will likely fail (like 0 or 1) and then use that to initialize.
// Not sure which approach is the most optimal but in practice this should work well.
match self.get_chat_template_impl(200) {
Ok(t) => Ok(t),
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
Err(InternalChatTemplateError::RetryWithLargerBuffer(actual_len)) => match self.get_chat_template_impl(actual_len + 1) {
Ok(t) => Ok(t),
Err(InternalChatTemplateError::Permanent(e)) => Err(e),
Err(InternalChatTemplateError::RetryWithLargerBuffer(unexpected_len)) => panic!("Was told that the template length was {actual_len} but now it's {unexpected_len}"),
}
}
}

/// Loads a model from a file.
@@ -526,15 +607,25 @@ impl LlamaModel {
/// Apply the models chat template to some messages.
/// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
///
/// `tmpl` of None means to use the default template provided by llama.cpp for the model
/// Unlike llama.cpp's own `apply_chat_template`, which silently falls back to the ChatML template when
/// given a null pointer for the template, this requires an explicit template to be specified. If you want
/// to use "chatml", just do `LlamaChatTemplate::new("chatml")`, or pass any other supported model name or
/// template string.
///
/// Use [Self::get_chat_template] to retrieve the template baked into the model (this is the preferred
/// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
///
/// You probably want to set `add_ass` to true so that the generated template string ends with the
/// opening tag of the assistant turn. If you don't leave that tag hanging, the model will likely generate
/// one into the output itself, and the output may contain other unexpected content as well.
///
/// # Errors
/// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
#[tracing::instrument(skip_all)]
pub fn apply_chat_template(
&self,
tmpl: Option<String>,
chat: Vec<LlamaChatMessage>,
tmpl: &LlamaChatTemplate,
chat: &[LlamaChatMessage],
add_ass: bool,
) -> Result<String, ApplyChatTemplateError> {
// Buffer is twice the length of messages per their recommendation
@@ -552,12 +643,7 @@
})
.collect();

// Set the tmpl pointer
let tmpl = tmpl.map(CString::new);
let tmpl_ptr = match &tmpl {
Some(str) => str.as_ref().map_err(Clone::clone)?.as_ptr(),
None => std::ptr::null(),
};
let tmpl_ptr = tmpl.0.as_ptr();

let res = unsafe {
llama_cpp_sys_2::llama_chat_apply_template(
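Taken together, the reworked public API is: fetch the template that ships with the model once, then feed it and the message list to `apply_chat_template`. A rough caller-side sketch, assuming a loaded `LlamaModel`, the example crate's `anyhow` dependency, and that `LlamaChatMessage::new(role, content)` is the crate's message constructor (check the crate docs for exact signatures); this is illustrative, not an excerpt from the PR:

use llama_cpp_2::model::{LlamaChatMessage, LlamaChatTemplate, LlamaModel};

fn build_prompt(model: &LlamaModel) -> anyhow::Result<String> {
    // Prefer the template baked into the GGUF metadata; fall back to an explicit
    // shortcode such as "chatml" only when the model ships without one.
    let template = model
        .get_chat_template()
        .unwrap_or(LlamaChatTemplate::new("chatml")?);

    // Assumed constructor: LlamaChatMessage::new(role, content).
    let messages = vec![
        LlamaChatMessage::new("system".into(), "You are a helpful assistant.".into())?,
        LlamaChatMessage::new("user".into(), "What is the capital of France?".into())?,
    ];

    // add_ass = true leaves the assistant turn open so generation continues from it.
    let prompt = model.apply_chat_template(&template, &messages, true)?;
    Ok(prompt)
}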