
Commit 8447c4d

Merge pull request #194 from bruceunx/feat-chat-template
added feature get chat template from gguf model
2 parents fae0864 + c7967a7 commit 8447c4d

File tree

2 files changed: +43 -1 lines changed


llama-cpp-2/src/lib.rs

Lines changed: 10 additions & 0 deletions
@@ -41,6 +41,9 @@ pub enum LLamaCppError {
     /// is idempotent.
     #[error("BackendAlreadyInitialized")]
     BackendAlreadyInitialized,
+    /// There was an error while getting the chat template from the model.
+    #[error("{0}")]
+    ChatTemplateError(#[from] ChatTemplateError),
     /// There was an error while decoding a batch.
     #[error("{0}")]
     DecodeError(#[from] DecodeError),
@@ -58,6 +61,13 @@ pub enum LLamaCppError {
     EmbeddingError(#[from] EmbeddingsError),
 }
 
+#[derive(Debug, Eq, PartialEq, thiserror::Error)]
+pub enum ChatTemplateError {
+    /// gguf has no chat template
+    #[error("model has no chat template in gguf")]
+    NullReturn,
+}
+
 /// Failed to Load context
 #[derive(Debug, Eq, PartialEq, thiserror::Error)]
 pub enum LlamaContextLoadError {
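
Because the new ChatTemplateError variant on LLamaCppError uses #[from], thiserror generates a From impl, so the ? operator converts the error automatically. A minimal sketch of calling code, assuming the crate is used as llama_cpp_2 and a loaded LlamaModel is already available (only get_chat_template and the error types come from this commit; the read_template helper is hypothetical):

use llama_cpp_2::model::LlamaModel;
use llama_cpp_2::LLamaCppError;

// Hypothetical helper: `?` converts ChatTemplateError into LLamaCppError
// through the `From` impl generated by the `#[from]` attribute added in this commit.
fn read_template(model: &LlamaModel) -> Result<String, LLamaCppError> {
    let template = model.get_chat_template()?;
    Ok(template)
}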

llama-cpp-2/src/model.rs

Lines changed: 33 additions & 1 deletion
@@ -10,7 +10,10 @@ use crate::llama_backend::LlamaBackend;
 use crate::model::params::LlamaModelParams;
 use crate::token::LlamaToken;
 use crate::token_type::LlamaTokenType;
-use crate::{LlamaContextLoadError, LlamaModelLoadError, StringToTokenError, TokenToStringError};
+use crate::{
+    ChatTemplateError, LlamaContextLoadError, LlamaModelLoadError, StringToTokenError,
+    TokenToStringError,
+};
 
 pub mod params;
 
@@ -274,6 +277,35 @@ impl LlamaModel {
         unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
     }
 
+    /// get chat template from model
+    /// let chat_template = model.get_chat_template()?;
+    ///
+    pub fn get_chat_template(&self) -> Result<String, ChatTemplateError> {
+        let chat_template: String = unsafe {
+            // longest known template is about 1200 bytes from llama.cpp
+            let chat_temp = match CString::new(Vec::<u8>::with_capacity(2048)) {
+                Ok(c) => c,
+                Err(_) => return Err(ChatTemplateError::NullReturn),
+            };
+            let chat_ptr = chat_temp.into_raw();
+            let chat_name = match CString::new("tokenizer.chat_template") {
+                Ok(c) => c,
+                Err(_) => return Err(ChatTemplateError::NullReturn),
+            };
+            llama_cpp_sys_2::llama_model_meta_val_str(
+                self.model.as_ptr(),
+                chat_name.as_ptr(),
+                chat_ptr,
+                250,
+            );
+            match CString::from_raw(chat_ptr).to_str() {
+                Ok(s) => s.to_string(),
+                Err(_) => return Err(ChatTemplateError::NullReturn),
+            }
+        };
+        Ok(chat_template)
+    }
+
     /// loads a model from a file.
     ///
     /// # Errors
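
A minimal usage sketch of the new method, assuming a LlamaModel has already been loaded into a variable named model (loading code omitted; only get_chat_template and ChatTemplateError::NullReturn come from this commit):

use llama_cpp_2::ChatTemplateError;

// `model` is an already-loaded llama_cpp_2::model::LlamaModel.
match model.get_chat_template() {
    Ok(template) => println!("chat template from gguf:\n{template}"),
    Err(ChatTemplateError::NullReturn) => eprintln!("model has no chat template in gguf"),
}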

0 commit comments