Commit 97c38ad

updated llama.cpp
1 parent 8cc7022 commit 97c38ad

3 files changed: +29 -22 lines changed

llama-cpp-2/src/context/kv_cache.rs

Lines changed: 2 additions & 2 deletions
@@ -238,11 +238,11 @@ impl<'a> KVCacheView<'a> {
         unsafe {
             std::slice::from_raw_parts(
                 self.view.cells_sequences,
-                usize::try_from(self.view.n_cells * self.view.n_max_seq)
+                usize::try_from(self.view.n_cells * self.view.n_seq_max)
                     .expect("failed to fit n_cells * n_max_seq into usize"),
             )
         }
-        .chunks(usize::try_from(self.view.n_max_seq).expect("failed to fit n_max_seq into usize"))
+        .chunks(usize::try_from(self.view.n_seq_max).expect("failed to fit n_max_seq into usize"))
     }
 }

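The only functional change above is the rename of the view field n_max_seq to n_seq_max, tracking upstream llama.cpp; the chunking logic itself is unchanged. As a reference point (not part of the commit), here is a minimal std-only sketch of that pattern with hypothetical data: a flat buffer of n_cells * n_seq_max sequence ids, read back one cell at a time with chunks().

fn main() {
    // Hypothetical data: 3 KV cells, up to 2 sequences per cell (n_seq_max = 2).
    // In the real KVCacheView this buffer is n_cells * n_seq_max ids exposed
    // through a raw pointer, which the view code above turns into a slice first.
    let n_seq_max = 2;
    let cells_sequences: [i32; 6] = [0, 1, 0, 3, 2, 2];

    // Same pattern as the view code above: one chunk of sequence ids per cell.
    for (cell, seqs) in cells_sequences.chunks(n_seq_max).enumerate() {
        println!("cell {cell}: sequence ids {seqs:?}");
    }
}
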
llama-cpp-2/src/lib.rs

Lines changed: 6 additions & 2 deletions
@@ -61,11 +61,15 @@ pub enum LLamaCppError {
     EmbeddingError(#[from] EmbeddingsError),
 }
 
+/// There was an error while getting the chat template from a model.
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum ChatTemplateError {
     /// gguf has no chat template
-    #[error("model has no chat template in gguf")]
-    NullReturn,
+    #[error("the model has no meta val - returned code {0}")]
+    MissingTemplate(i32),
+    /// The chat template was not valid utf8.
+    #[error(transparent)]
+    Utf8Error(#[from] std::str::Utf8Error),
 }
 
 /// Failed to Load context

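For reference (not part of the diff): the two new variants let a caller distinguish a missing template from a badly encoded one, and #[from] std::str::Utf8Error means the ? operator converts UTF-8 failures into ChatTemplateError automatically. A small caller-side sketch, assuming the enum sits at the crate root as in lib.rs above:

use llama_cpp_2::ChatTemplateError;

// Sketch: map the two variants onto human-readable messages.
fn describe(err: &ChatTemplateError) -> String {
    match err {
        // llama_model_meta_val_str returned a negative code: no template in the gguf
        ChatTemplateError::MissingTemplate(code) => {
            format!("model has no chat template in gguf (return code {code})")
        }
        // the template bytes were not valid UTF-8 (converted via #[from])
        ChatTemplateError::Utf8Error(e) => format!("chat template is not valid utf8: {e}"),
    }
}

fn main() {
    println!("{}", describe(&ChatTemplateError::MissingTemplate(-1)));
}
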
llama-cpp-2/src/model.rs

Lines changed: 21 additions & 18 deletions
@@ -277,32 +277,35 @@ impl LlamaModel {
         unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
     }
 
-    /// get chat template from model
-    /// let chat_template = model.get_chat_template()?;
+    /// Get chat template from model.
     ///
-    pub fn get_chat_template(&self) -> Result<String, ChatTemplateError> {
+    /// # Errors
+    ///
+    /// * If the model has no chat template
+    /// * If the chat template is not a valid [`CString`].
+    #[allow(clippy::missing_panics_doc)] // we statically know this will not panic as
+    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
+
+        // longest known template is about 1200 bytes from llama.cpp
+        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
+        let chat_ptr = chat_temp.into_raw();
+        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
+
         let chat_template: String = unsafe {
-            // longest known template is about 1200 bytes from llama.cpp
-            let chat_temp = match CString::new(Vec::<u8>::with_capacity(2048)) {
-                Ok(c) => c,
-                Err(_) => return Err(ChatTemplateError::NullReturn),
-            };
-            let chat_ptr = chat_temp.into_raw();
-            let chat_name = match CString::new("tokenizer.chat_template") {
-                Ok(c) => c,
-                Err(_) => return Err(ChatTemplateError::NullReturn),
-            };
-            llama_cpp_sys_2::llama_model_meta_val_str(
+            let ret = llama_cpp_sys_2::llama_model_meta_val_str(
                 self.model.as_ptr(),
                 chat_name.as_ptr(),
                 chat_ptr,
-                250,
+                buf_size,
            );
-            match CString::from_raw(chat_ptr).to_str() {
-                Ok(s) => s.to_string(),
-                Err(_) => return Err(ChatTemplateError::NullReturn),
+            if ret < 0 {
+                return Err(ChatTemplateError::MissingTemplate(ret));
             }
+            let template = CString::from_raw(chat_ptr).to_str()?.to_string();
+            debug_assert_eq!(usize::try_from(ret).unwrap(), template.len(), "llama.cpp guarantees that the returned int {ret} is the length of the string {} but that was not the case", template.len());
+            template
         };
+
         Ok(chat_template)
     }
 
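Taken together with the lib.rs change, the caller now supplies the metadata buffer size and gets a typed error back. A rough usage sketch, not from the commit: the model::LlamaModel path is assumed from the file layout, model loading is omitted, and 2048 is just a buffer size comfortably above the ~1200-byte templates mentioned in the code comment above.

use llama_cpp_2::{model::LlamaModel, ChatTemplateError};

// Sketch of calling the updated API; the caller now picks the buffer size.
fn print_chat_template(model: &LlamaModel) {
    match model.get_chat_template(2048) {
        Ok(template) => println!("chat template:\n{template}"),
        Err(ChatTemplateError::MissingTemplate(code)) => {
            eprintln!("model has no chat template (llama.cpp returned {code})");
        }
        Err(ChatTemplateError::Utf8Error(e)) => {
            eprintln!("chat template is not valid utf8: {e}");
        }
    }
}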