Skip to content

Commit 26edb07

Browse files
committed
Update tokenizers library to 0.13.4
Includes changes that take advantage of the tokenizers library's new API, which allows decoding from slices rather than taking ownership of `Vec`s. Signed-off-by: Nick Hill <[email protected]>
1 parent ab351a8 commit 26edb07

File tree

4 files changed

+95
-30
lines changed

4 files changed

+95
-30
lines changed

Cargo.lock

Lines changed: 33 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

router/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ serde_json = "^1.0.103"
3434
# spin comes in via tonic->tokio-rustls->rustls->ring but this pins a specific old version 0.5.2 :(
3535
#spin = "=0.9.8"
3636
thiserror = "^1.0.43"
37-
tokenizers = "^0.13.3"
37+
tokenizers = "0.13.4"
3838
tokio = { version = "^1.29.1", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "fs"] }
3939
tracing = "^0.1.37"
4040
tracing-subscriber = { version = "0.3.16", features = ["json"] }

router/src/decoder.rs

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::mem::take;
2+
use std::slice::from_ref;
23
use tokenizers::DecoderWrapper::{BPE, ByteLevel, Metaspace, WordPiece, CTC, Sequence};
34
use tokenizers::{Error, Tokenizer};
45
use unicode_segmentation::UnicodeSegmentation;
@@ -25,15 +26,15 @@ impl Decoder {
2526
.expect("Tokenizer setup error").get_ids().first().unwrap();
2627
Decoder {
2728
single_tok_id: prefix_id,
28-
single_tok: tokenizer.decode(vec![prefix_id], false).unwrap(),
29+
single_tok: tokenizer.decode(from_ref(&prefix_id), false).unwrap(),
2930
tokenizer,
3031
seq2seq,
3132
eos_token_id,
3233
skip_special_toks,
3334
}
3435
}
3536

36-
fn decode_full(&self, ids: Vec<u32>) -> Result<String, InferError> {
37+
fn decode_full(&self, ids: &[u32]) -> Result<String, InferError> {
3738
self.tokenizer.decode(ids, self.skip_special_toks).map_err(Error::into)
3839
}
3940

@@ -48,7 +49,7 @@ impl Decoder {
4849
if (first && self.seq2seq) || (last && matches![decoder, Some(BPE(_))])
4950
|| matches![decoder, Some(ByteLevel(_) | CTC(_))] {
5051
// In these cases we don't need to do anything special for "continuation"
51-
let mut text = self.decode_full(ids)?;
52+
let mut text = self.decode_full(&*ids)?;
5253
text.truncate(text.trim_end_matches('�').len()); // Avoid add'l allocation
5354
return Ok(text)
5455
}
@@ -59,7 +60,7 @@ impl Decoder {
5960
// For these, the first token in the sequence is treated differently,
6061
// so we add and then strip a placeholder token.
6162
ids.insert(0, self.single_tok_id);
62-
let result = self.decode_full(ids)?;
63+
let result = self.decode_full(&*ids)?;
6364
let mut text = result.strip_prefix(&self.single_tok).ok_or_else(
6465
|| DetokenizationError("Unexpected".into())
6566
)?.to_string();
@@ -68,10 +69,53 @@ impl Decoder {
6869
},
6970
Some(BPE(_)) => {
7071
ids.push(self.single_tok_id);
71-
let result = self.decode_full(ids)?;
72-
Ok(result.strip_suffix(&self.single_tok)
73-
.ok_or_else(|| DetokenizationError("Unexpected".into()))
74-
?.to_string())
72+
let result = self.decode_full(&*ids)?;
73+
Ok(result.strip_suffix(&self.single_tok).ok_or_else(
74+
|| DetokenizationError("Unexpected".into())
75+
)?.to_string())
76+
},
77+
None => {
78+
// Just prepend a space
79+
Ok(format!(" {}", self.decode_full(&*ids)?))
80+
},
81+
Some(tok) => {
82+
Err(DetokenizationError(format!("Unsupported tokenizer type: {:?}", tok)))
83+
}
84+
}
85+
}
86+
87+
pub(crate) fn decode_ref(
88+
&self, ids: &[u32], first: bool, last: bool,
89+
) -> Result<String, InferError> {
90+
let decoder = self.tokenizer.get_decoder();
91+
if (first && self.seq2seq) || (last && matches![decoder, Some(BPE(_))])
92+
|| matches![decoder, Some(ByteLevel(_) | CTC(_))] {
93+
// In these cases we don't need to do anything special for "continuation"
94+
let mut text = self.decode_full(ids)?;
95+
text.truncate(text.trim_end_matches('�').len()); // Avoid add'l allocation
96+
return Ok(text)
97+
}
98+
// How we handle continuation depends on the specific decoder's behaviour,
99+
// see each one's implementation of decode_chain in the tokenizers library.
100+
match self.tokenizer.get_decoder() {
101+
Some(Metaspace(_) | WordPiece(_) | Sequence(_)) => {
102+
// For these, the first token in the sequence is treated differently,
103+
// so we add and then strip a placeholder token.
104+
let ids = [from_ref(&0), ids].concat();
105+
let result = self.decode_full(&*ids)?;
106+
let mut text = result.strip_prefix(&self.single_tok).ok_or_else(
107+
|| DetokenizationError("Unexpected".into())
108+
)?.to_string();
109+
text.truncate(text.trim_end_matches('�').len()); // Avoid add'l allocation
110+
Ok(text)
111+
},
112+
Some(BPE(_)) => {
113+
let ids = [ids, from_ref(&self.single_tok_id)].concat();
114+
// ids.push(self.single_tok_id);
115+
let result = self.decode_full(&*ids)?;
116+
Ok(result.strip_suffix(&self.single_tok).ok_or_else(
117+
|| DetokenizationError("Unexpected".into())
118+
)?.to_string())
75119
},
76120
None => {
77121
// Just prepend a space
@@ -158,7 +202,7 @@ pub(crate) struct IncrementalFirstDiffDecoder {
158202

159203
impl IncrementalDecoder for IncrementalFirstDiffDecoder {
160204
fn next(&mut self, token: u32, decoder: &Decoder) -> Result<String, InferError> {
161-
let text = decoder.decode(vec![token], self.first, false)?;
205+
let text = decoder.decode_ref(from_ref(&token), self.first, false)?;
162206
self.first = false;
163207
self.output += &text;
164208
Ok(text)
@@ -182,7 +226,7 @@ impl IncrementalDecoder for IncrementalLastDiffDecoder {
182226
fn next(&mut self, token: u32, decoder: &Decoder) -> Result<String, InferError> {
183227
let text = self.next_id.map_or_else(
184228
|| Ok(String::new()),
185-
|id| decoder.decode(vec![id], true, false)
229+
|ref id| decoder.decode_ref(from_ref(id), true, false)
186230
)?;
187231
self.next_id = Some(token);
188232
self.output += &text;
@@ -192,7 +236,7 @@ impl IncrementalDecoder for IncrementalLastDiffDecoder {
192236
fn flush(&mut self, decoder: &Decoder) -> Result<String, InferError> {
193237
let text = self.next_id.map_or_else(
194238
|| Ok(String::new()),
195-
|id| decoder.decode_full(vec![id])
239+
|ref id| decoder.decode_full(from_ref(id))
196240
)?;
197241
self.next_id = None;
198242
self.output += &text;
@@ -219,7 +263,7 @@ impl IncrementalDecoder for IncrementalDeDupDecoder {
219263
return Ok(String::new())
220264
}
221265
self.last_id = Some(token);
222-
let text = decoder.decode_full(vec![token])?;
266+
let text = decoder.decode_full(from_ref(&token))?;
223267
self.output += &text;
224268
Ok(text)
225269
}
@@ -257,11 +301,11 @@ impl IncrementalBLDecoder {
257301
impl IncrementalDecoder for IncrementalBLDecoder {
258302
fn next(&mut self, token: u32, decoder: &Decoder) -> Result<String, InferError> {
259303
self.id_buffer.push(token);
260-
let mut buffer = self.id_buffer.clone();
304+
let buffer = &*self.id_buffer;
261305
let text = if self.first_diff && !self.first {
262306
// Prepend placeholder token to avoid first-token differences
263-
buffer.insert(0, decoder.single_tok_id);
264-
let result = decoder.decode_full(buffer)?;
307+
let buffer = [from_ref(&decoder.single_tok_id), buffer].concat();
308+
let result = decoder.decode_full(&*buffer)?;
265309
result.strip_prefix(&decoder.single_tok).ok_or_else(
266310
|| DetokenizationError("Unexpected".into())
267311
)?.to_string()
@@ -291,7 +335,7 @@ impl IncrementalDecoder for IncrementalBLDecoder {
291335
}
292336
fn flush(&mut self, decoder: &Decoder) -> Result<String, InferError> {
293337
if !self.id_buffer.is_empty() {
294-
let last = decoder.decode_full(self.id_buffer.clone())?;
338+
let last = decoder.decode_full(&*self.id_buffer)?;
295339
let last = last.trim_end_matches('�');
296340
self.output += last;
297341
self.str_buffer.push_str(last);

router/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ fn main() -> Result<(), std::io::Error> {
8282
);
8383
}
8484
}
85-
tokenizer.with_truncation(None).with_padding(None);
85+
tokenizer.with_truncation(None).unwrap().with_padding(None);
8686

8787
// Launch Tokio runtime
8888
tokio::runtime::Builder::new_multi_thread()

0 commit comments

Comments
 (0)