Skip to content

Commit dc52d3b

Browse files
committed
updated to latest llama.cpp (seems to run llama-3)
1 parent 7309252 commit dc52d3b

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

llama-cpp-2/src/model.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ pub enum AddBos {
5151
Never,
5252
}
5353

54+
/// Determines whether special tokens should be tokenized.
55+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
56+
pub enum Special {
57+
/// Allow tokenizing special and/or control tokens, which are otherwise not exposed and are treated as plaintext. Does not insert a leading space.
58+
Tokenize,
59+
/// Treat special and/or control tokens as plaintext.
60+
Plaintext,
61+
}
62+
5463
unsafe impl Send for LlamaModel {}
5564

5665
unsafe impl Sync for LlamaModel {}
@@ -71,10 +80,11 @@ impl LlamaModel {
7180
/// Get all tokens in the model.
7281
pub fn tokens(
7382
&self,
83+
special: Special,
7484
) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
7585
(0..self.n_vocab())
7686
.map(LlamaToken::new)
77-
.map(|llama_token| (llama_token, self.token_to_str(llama_token)))
87+
.map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
7888
}
7989

8090
/// Get the beginning of stream token.
@@ -103,8 +113,8 @@ impl LlamaModel {
103113
/// # Errors
104114
///
105115
/// See [`TokenToStringError`] for more information.
106-
pub fn token_to_str(&self, token: LlamaToken) -> Result<String, TokenToStringError> {
107-
self.token_to_str_with_size(token, 32)
116+
pub fn token_to_str(&self, token: LlamaToken, special: Special) -> Result<String, TokenToStringError> {
117+
self.token_to_str_with_size(token, 32, special)
108118
}
109119

110120
/// Convert single token to bytes.
@@ -121,9 +131,9 @@ impl LlamaModel {
121131
/// # Errors
122132
///
123133
/// See [`TokenToStringError`] for more information.
124-
pub fn tokens_to_str(&self, tokens: &[LlamaToken]) -> Result<String, TokenToStringError> {
134+
pub fn tokens_to_str(&self, tokens: &[LlamaToken], special: Special) -> Result<String, TokenToStringError> {
125135
let mut builder = String::with_capacity(tokens.len() * 4);
126-
for str in tokens.iter().copied().map(|t| self.token_to_str(t)) {
136+
for str in tokens.iter().copied().map(|t| self.token_to_str(t, special)) {
127137
builder += &str?;
128138
}
129139
Ok(builder)
@@ -236,6 +246,7 @@ impl LlamaModel {
236246
&self,
237247
token: LlamaToken,
238248
buffer_size: usize,
249+
special: Special,
239250
) -> Result<String, TokenToStringError> {
240251
let bytes = self.token_to_bytes_with_size(token, buffer_size)?;
241252
Ok(String::from_utf8(bytes)?)
@@ -264,6 +275,7 @@ impl LlamaModel {
264275
return Ok(String::from("\n").into_bytes());
265276
}
266277

278+
// Unsure how the token-type handling below should interact with the `special` argument.
267279
match self.token_type(token) {
268280
LlamaTokenType::Normal | LlamaTokenType::UserDefined => {}
269281
LlamaTokenType::Control => {
@@ -279,12 +291,17 @@ impl LlamaModel {
279291
}
280292
}
281293

294+
let special = match special {
295+
Special::Tokenize => true,
296+
Special::Plaintext => false,
297+
};
298+
282299
let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
283300
let len = string.as_bytes().len();
284301
let len = c_int::try_from(len).expect("length fits into c_int");
285302
let buf = string.into_raw();
286303
let size = unsafe {
287-
llama_cpp_sys_2::llama_token_to_piece(self.model.as_ptr(), token.0, buf, len)
304+
llama_cpp_sys_2::llama_token_to_piece(self.model.as_ptr(), token.0, buf, len, special)
288305
};
289306

290307
match size {

llama-cpp-sys-2/llama.cpp

simple/src/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use llama_cpp_2::llama_backend::LlamaBackend;
1515
use llama_cpp_2::llama_batch::LlamaBatch;
1616
use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
1717
use llama_cpp_2::model::params::LlamaModelParams;
18-
use llama_cpp_2::model::AddBos;
18+
use llama_cpp_2::model::{AddBos, Special};
1919
use llama_cpp_2::model::LlamaModel;
2020
use llama_cpp_2::token::data_array::LlamaTokenDataArray;
2121
use std::ffi::CString;
@@ -214,7 +214,7 @@ either reduce n_len or increase n_ctx"
214214
eprintln!();
215215

216216
for token in &tokens_list {
217-
eprint!("{}", model.token_to_str(*token)?);
217+
eprint!("{}", model.token_to_str(*token, Special::Tokenize)?);
218218
}
219219

220220
std::io::stderr().flush()?;

0 commit comments

Comments
 (0)