1
1
use std:: mem:: take;
2
+ use std:: slice:: from_ref;
2
3
use tokenizers:: DecoderWrapper :: { BPE , ByteLevel , Metaspace , WordPiece , CTC , Sequence } ;
3
4
use tokenizers:: { Error , Tokenizer } ;
4
5
use unicode_segmentation:: UnicodeSegmentation ;
@@ -25,15 +26,15 @@ impl Decoder {
25
26
. expect ( "Tokenizer setup error" ) . get_ids ( ) . first ( ) . unwrap ( ) ;
26
27
Decoder {
27
28
single_tok_id : prefix_id,
28
- single_tok : tokenizer. decode ( vec ! [ prefix_id] , false ) . unwrap ( ) ,
29
+ single_tok : tokenizer. decode ( from_ref ( & prefix_id) , false ) . unwrap ( ) ,
29
30
tokenizer,
30
31
seq2seq,
31
32
eos_token_id,
32
33
skip_special_toks,
33
34
}
34
35
}
35
36
36
- fn decode_full ( & self , ids : Vec < u32 > ) -> Result < String , InferError > {
37
+ fn decode_full ( & self , ids : & [ u32 ] ) -> Result < String , InferError > {
37
38
self . tokenizer . decode ( ids, self . skip_special_toks ) . map_err ( Error :: into)
38
39
}
39
40
@@ -48,7 +49,7 @@ impl Decoder {
48
49
if ( first && self . seq2seq ) || ( last && matches ! [ decoder, Some ( BPE ( _) ) ] )
49
50
|| matches ! [ decoder, Some ( ByteLevel ( _) | CTC ( _) ) ] {
50
51
// In these cases we don't need to do anything special for "continuation"
51
- let mut text = self . decode_full ( ids) ?;
52
+ let mut text = self . decode_full ( & * ids) ?;
52
53
text. truncate ( text. trim_end_matches ( '�' ) . len ( ) ) ; // Avoid add'l allocation
53
54
return Ok ( text)
54
55
}
@@ -59,7 +60,7 @@ impl Decoder {
59
60
// For these, the first token in the sequence is treated differently,
60
61
// so we add and then strip a placeholder token.
61
62
ids. insert ( 0 , self . single_tok_id ) ;
62
- let result = self . decode_full ( ids) ?;
63
+ let result = self . decode_full ( & * ids) ?;
63
64
let mut text = result. strip_prefix ( & self . single_tok ) . ok_or_else (
64
65
|| DetokenizationError ( "Unexpected" . into ( ) )
65
66
) ?. to_string ( ) ;
@@ -68,10 +69,53 @@ impl Decoder {
68
69
} ,
69
70
Some ( BPE ( _) ) => {
70
71
ids. push ( self . single_tok_id ) ;
71
- let result = self . decode_full ( ids) ?;
72
- Ok ( result. strip_suffix ( & self . single_tok )
73
- . ok_or_else ( || DetokenizationError ( "Unexpected" . into ( ) ) )
74
- ?. to_string ( ) )
72
+ let result = self . decode_full ( & * ids) ?;
73
+ Ok ( result. strip_suffix ( & self . single_tok ) . ok_or_else (
74
+ || DetokenizationError ( "Unexpected" . into ( ) )
75
+ ) ?. to_string ( ) )
76
+ } ,
77
+ None => {
78
+ // Just prepend a space
79
+ Ok ( format ! ( " {}" , self . decode_full( & * ids) ?) )
80
+ } ,
81
+ Some ( tok) => {
82
+ Err ( DetokenizationError ( format ! ( "Unsupported tokenizer type: {:?}" , tok) ) )
83
+ }
84
+ }
85
+ }
86
+
87
+ pub ( crate ) fn decode_ref (
88
+ & self , ids : & [ u32 ] , first : bool , last : bool ,
89
+ ) -> Result < String , InferError > {
90
+ let decoder = self . tokenizer . get_decoder ( ) ;
91
+ if ( first && self . seq2seq ) || ( last && matches ! [ decoder, Some ( BPE ( _) ) ] )
92
+ || matches ! [ decoder, Some ( ByteLevel ( _) | CTC ( _) ) ] {
93
+ // In these cases we don't need to do anything special for "continuation"
94
+ let mut text = self . decode_full ( ids) ?;
95
+ text. truncate ( text. trim_end_matches ( '�' ) . len ( ) ) ; // Avoid add'l allocation
96
+ return Ok ( text)
97
+ }
98
+ // How we handle continuation depends on the specific decoder's behaviour,
99
+ // see each one's implementation of decode_chain in the tokenizers library.
100
+ match self . tokenizer . get_decoder ( ) {
101
+ Some ( Metaspace ( _) | WordPiece ( _) | Sequence ( _) ) => {
102
+ // For these, the first token in the sequence is treated differently,
103
+ // so we add and then strip a placeholder token.
104
+ let ids = [ from_ref ( & 0 ) , ids] . concat ( ) ;
105
+ let result = self . decode_full ( & * ids) ?;
106
+ let mut text = result. strip_prefix ( & self . single_tok ) . ok_or_else (
107
+ || DetokenizationError ( "Unexpected" . into ( ) )
108
+ ) ?. to_string ( ) ;
109
+ text. truncate ( text. trim_end_matches ( '�' ) . len ( ) ) ; // Avoid add'l allocation
110
+ Ok ( text)
111
+ } ,
112
+ Some ( BPE ( _) ) => {
113
+ let ids = [ ids, from_ref ( & self . single_tok_id ) ] . concat ( ) ;
114
+ // ids.push(self.single_tok_id);
115
+ let result = self . decode_full ( & * ids) ?;
116
+ Ok ( result. strip_suffix ( & self . single_tok ) . ok_or_else (
117
+ || DetokenizationError ( "Unexpected" . into ( ) )
118
+ ) ?. to_string ( ) )
75
119
} ,
76
120
None => {
77
121
// Just prepend a space
@@ -158,7 +202,7 @@ pub(crate) struct IncrementalFirstDiffDecoder {
158
202
159
203
impl IncrementalDecoder for IncrementalFirstDiffDecoder {
160
204
fn next ( & mut self , token : u32 , decoder : & Decoder ) -> Result < String , InferError > {
161
- let text = decoder. decode ( vec ! [ token] , self . first , false ) ?;
205
+ let text = decoder. decode_ref ( from_ref ( & token) , self . first , false ) ?;
162
206
self . first = false ;
163
207
self . output += & text;
164
208
Ok ( text)
@@ -182,7 +226,7 @@ impl IncrementalDecoder for IncrementalLastDiffDecoder {
182
226
fn next ( & mut self , token : u32 , decoder : & Decoder ) -> Result < String , InferError > {
183
227
let text = self . next_id . map_or_else (
184
228
|| Ok ( String :: new ( ) ) ,
185
- |id| decoder. decode ( vec ! [ id ] , true , false )
229
+ |ref id| decoder. decode_ref ( from_ref ( id ) , true , false )
186
230
) ?;
187
231
self . next_id = Some ( token) ;
188
232
self . output += & text;
@@ -192,7 +236,7 @@ impl IncrementalDecoder for IncrementalLastDiffDecoder {
192
236
fn flush ( & mut self , decoder : & Decoder ) -> Result < String , InferError > {
193
237
let text = self . next_id . map_or_else (
194
238
|| Ok ( String :: new ( ) ) ,
195
- |id| decoder. decode_full ( vec ! [ id ] )
239
+ |ref id| decoder. decode_full ( from_ref ( id ) )
196
240
) ?;
197
241
self . next_id = None ;
198
242
self . output += & text;
@@ -219,7 +263,7 @@ impl IncrementalDecoder for IncrementalDeDupDecoder {
219
263
return Ok ( String :: new ( ) )
220
264
}
221
265
self . last_id = Some ( token) ;
222
- let text = decoder. decode_full ( vec ! [ token] ) ?;
266
+ let text = decoder. decode_full ( from_ref ( & token) ) ?;
223
267
self . output += & text;
224
268
Ok ( text)
225
269
}
@@ -257,11 +301,11 @@ impl IncrementalBLDecoder {
257
301
impl IncrementalDecoder for IncrementalBLDecoder {
258
302
fn next ( & mut self , token : u32 , decoder : & Decoder ) -> Result < String , InferError > {
259
303
self . id_buffer . push ( token) ;
260
- let mut buffer = self . id_buffer . clone ( ) ;
304
+ let buffer = & * self . id_buffer ;
261
305
let text = if self . first_diff && !self . first {
262
306
// Prepend placeholder token to avoid first-token differences
263
- buffer. insert ( 0 , decoder. single_tok_id ) ;
264
- let result = decoder. decode_full ( buffer) ?;
307
+ let buffer = [ from_ref ( & decoder. single_tok_id ) , buffer ] . concat ( ) ;
308
+ let result = decoder. decode_full ( & * buffer) ?;
265
309
result. strip_prefix ( & decoder. single_tok ) . ok_or_else (
266
310
|| DetokenizationError ( "Unexpected" . into ( ) )
267
311
) ?. to_string ( )
@@ -291,7 +335,7 @@ impl IncrementalDecoder for IncrementalBLDecoder {
291
335
}
292
336
fn flush ( & mut self , decoder : & Decoder ) -> Result < String , InferError > {
293
337
if !self . id_buffer . is_empty ( ) {
294
- let last = decoder. decode_full ( self . id_buffer . clone ( ) ) ?;
338
+ let last = decoder. decode_full ( & * self . id_buffer ) ?;
295
339
let last = last. trim_end_matches ( '�' ) ;
296
340
self . output += last;
297
341
self . str_buffer . push_str ( last) ;
0 commit comments