Skip to content

Commit 99d1c41

Browse files
authored
Merge pull request #316 from SilasMarvin/silas-fix-apply-chat-template
2 parents 584077b + 2cb4498 commit 99d1c41

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

llama-cpp-2/src/model.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,11 @@ impl LlamaModel {
113113
/// # Errors
114114
///
115115
/// See [`TokenToStringError`] for more information.
116-
pub fn token_to_str(&self, token: LlamaToken, special: Special) -> Result<String, TokenToStringError> {
116+
pub fn token_to_str(
117+
&self,
118+
token: LlamaToken,
119+
special: Special,
120+
) -> Result<String, TokenToStringError> {
117121
self.token_to_str_with_size(token, 32, special)
118122
}
119123

@@ -122,7 +126,11 @@ impl LlamaModel {
122126
/// # Errors
123127
///
124128
/// See [`TokenToStringError`] for more information.
125-
pub fn token_to_bytes(&self, token: LlamaToken, special: Special) -> Result<Vec<u8>, TokenToStringError> {
129+
pub fn token_to_bytes(
130+
&self,
131+
token: LlamaToken,
132+
special: Special,
133+
) -> Result<Vec<u8>, TokenToStringError> {
126134
self.token_to_bytes_with_size(token, 32, special)
127135
}
128136

@@ -131,9 +139,17 @@ impl LlamaModel {
131139
/// # Errors
132140
///
133141
/// See [`TokenToStringError`] for more information.
134-
pub fn tokens_to_str(&self, tokens: &[LlamaToken], special: Special) -> Result<String, TokenToStringError> {
142+
pub fn tokens_to_str(
143+
&self,
144+
tokens: &[LlamaToken],
145+
special: Special,
146+
) -> Result<String, TokenToStringError> {
135147
let mut builder = String::with_capacity(tokens.len() * 4);
136-
for str in tokens.iter().copied().map(|t| self.token_to_str(t, special)) {
148+
for str in tokens
149+
.iter()
150+
.copied()
151+
.map(|t| self.token_to_str(t, special))
152+
{
137153
builder += &str?;
138154
}
139155
Ok(builder)
@@ -451,12 +467,14 @@ impl LlamaModel {
451467
content: c.content.as_ptr(),
452468
})
453469
.collect();
470+
454471
// Set the tmpl pointer
455472
let tmpl = tmpl.map(CString::new);
456-
let tmpl_ptr = match tmpl {
457-
Some(str) => str?.as_ptr(),
473+
let tmpl_ptr = match &tmpl {
474+
Some(str) => str.as_ref().map_err(|e| e.clone())?.as_ptr(),
458475
None => std::ptr::null(),
459476
};
477+
460478
let formatted_chat = unsafe {
461479
let res = llama_cpp_sys_2::llama_chat_apply_template(
462480
self.model.as_ptr(),

0 commit comments

Comments
 (0)