Skip to content

Commit 99d9563

Browse files
authored
Merge pull request #232 from jiabochao/fix-multi-byte-decoding
fix: multi-byte utf8 decoding error
2 parents a0eebde + f9bd213 commit 99d9563

File tree

5 files changed

+56
-5
lines changed

5 files changed

+56
-5
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ bindgen = "0.69.4"
1919
cc = "1.0.90"
2020
anyhow = "1.0.81"
2121
clap = "4.5.4"
22+
encoding_rs = "0.8.33"
2223

2324
[workspace.lints.rust]
2425
missing_docs = { level = "warn" }

llama-cpp-2/src/model.rs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,15 @@ impl LlamaModel {
9090
self.token_to_str_with_size(token, 32)
9191
}
9292

93+
/// Convert a single token to bytes.
94+
///
95+
/// # Errors
96+
///
97+
/// See [`TokenToStringError`] for more information.
98+
pub fn token_to_bytes(&self, token: LlamaToken) -> Result<Vec<u8>, TokenToStringError> {
99+
self.token_to_bytes_with_size(token, 32)
100+
}
101+
93102
/// Convert a vector of tokens to a single string.
94103
///
95104
/// # Errors
@@ -211,22 +220,45 @@ impl LlamaModel {
211220
token: LlamaToken,
212221
buffer_size: usize,
213222
) -> Result<String, TokenToStringError> {
223+
let bytes = self.token_to_bytes_with_size(token, buffer_size)?;
224+
Ok(String::from_utf8(bytes)?)
225+
}
226+
227+
/// Convert a token to bytes with a specified buffer size.
228+
///
229+
/// Generally you should use [`LlamaModel::token_to_bytes`] instead as 32 bytes is enough for most words and
230+
/// the extra bytes do not really matter.
231+
///
232+
/// # Errors
233+
///
234+
/// - if the token type is unknown
235+
/// - the resultant token is larger than `buffer_size`.
236+
///
237+
/// # Panics
238+
///
239+
/// - if `buffer_size` does not fit into a [`c_int`].
240+
/// - if the returned size from llama-cpp does not fit into a [`usize`]. (this should never happen)
241+
pub fn token_to_bytes_with_size(
242+
&self,
243+
token: LlamaToken,
244+
buffer_size: usize,
245+
) -> Result<Vec<u8>, TokenToStringError> {
214246
if token == self.token_nl() {
215-
return Ok(String::from("\n"));
247+
return Ok(String::from("\n").into_bytes());
216248
}
217249

218250
match self.token_type(token) {
219251
LlamaTokenType::Normal | LlamaTokenType::UserDefined => {}
220252
LlamaTokenType::Control => {
221253
if token == self.token_bos() || token == self.token_eos() {
222-
return Ok(String::new());
254+
return Ok(Vec::new());
223255
}
224256
}
225257
LlamaTokenType::Unknown
226258
| LlamaTokenType::Undefined
227259
| LlamaTokenType::Byte
228260
| LlamaTokenType::Unused => {
229-
return Ok(String::new());
261+
return Ok(Vec::new());
230262
}
231263
}
232264

@@ -246,7 +278,7 @@ impl LlamaModel {
246278
let mut bytes = string.into_bytes();
247279
let len = usize::try_from(size).expect("size is positive and fits into usize");
248280
bytes.truncate(len);
249-
Ok(String::from_utf8(bytes)?)
281+
Ok(bytes)
250282
}
251283
}
252284
}

simple/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ llama-cpp-2 = { path = "../llama-cpp-2", version = "0.1.46" }
1010
hf-hub = { workspace = true }
1111
clap = { workspace = true , features = ["derive"] }
1212
anyhow = { workspace = true }
13+
encoding_rs = { workspace = true }
1314

1415
[features]
1516
cublas = ["llama-cpp-2/cublas"]

simple/src/main.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ either reduce n_len or increase n_ctx"
240240

241241
let t_main_start = ggml_time_us();
242242

243+
// Stateful UTF-8 `Decoder`: it buffers incomplete multi-byte sequences so a
// character split across two tokens is decoded correctly on the next call.
244+
let mut decoder = encoding_rs::UTF_8.new_decoder();
245+
243246
while n_cur <= n_len {
244247
// sample the next token
245248
{
@@ -256,7 +259,11 @@ either reduce n_len or increase n_ctx"
256259
break;
257260
}
258261

259-
print!("{}", model.token_to_str(new_token_id)?);
262+
let output_bytes = model.token_to_bytes(new_token_id)?;
263+
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
264+
let mut output_string = String::with_capacity(32);
265+
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
266+
print!("{}", output_string);
260267
std::io::stdout().flush()?;
261268

262269
batch.clear();

0 commit comments

Comments
 (0)