@@ -51,6 +51,15 @@ pub enum AddBos {
51
51
Never ,
52
52
}
53
53
54
+ /// Selects how special and/or control tokens are handled when a `special` argument is required; a `Special` value is lowered to the boolean flag passed to `llama_token_to_piece` (`Tokenize` => `true`, `Plaintext` => `false`).
55
+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
56
+ pub enum Special {
57
+ /// Allow tokenizing special and/or control tokens which otherwise are not exposed and are treated as plaintext (lowered to `special = true`). Does not insert a leading space.
58
+ Tokenize ,
59
+ /// Treat special and/or control tokens as plaintext (lowered to `special = false`).
60
+ Plaintext ,
61
+ }
62
+
54
63
// NOTE(review): these impls assert that `LlamaModel` may be moved to and shared between threads; the invariant that makes this sound is not visible here — confirm it and record it in a `SAFETY:` comment.
unsafe impl Send for LlamaModel { }
55
64
56
65
unsafe impl Sync for LlamaModel { }
@@ -71,10 +80,11 @@ impl LlamaModel {
71
80
/// Get all tokens in the model: lazily yields `(token, self.token_to_str(token, special))` for every token id in `0..self.n_vocab()`; `special` is forwarded unchanged to `token_to_str`, and a failed conversion is yielded as the `Err` half of the pair.
72
81
pub fn tokens (
73
82
& self ,
83
+ special : Special ,
74
84
) -> impl Iterator < Item = ( LlamaToken , Result < String , TokenToStringError > ) > + ' _ {
75
85
( 0 ..self . n_vocab ( ) )
76
86
. map ( LlamaToken :: new)
77
- . map ( |llama_token| ( llama_token, self . token_to_str ( llama_token) ) )
87
+ . map ( move |llama_token| ( llama_token, self . token_to_str ( llama_token, special ) ) ) // `move` copies `special` into the closure so the returned iterator does not borrow a local of this function
78
88
}
79
89
80
90
/// Get the beginning of stream token.
@@ -103,8 +113,8 @@ impl LlamaModel {
103
113
/// # Errors
104
114
///
105
115
/// See [`TokenToStringError`] for more information.
106
- pub fn token_to_str ( & self , token : LlamaToken ) -> Result < String , TokenToStringError > {
107
- self . token_to_str_with_size ( token, 32 )
116
+ pub fn token_to_str ( & self , token : LlamaToken , special : Special ) -> Result < String , TokenToStringError > {
117
+ self . token_to_str_with_size ( token, 32 , special )
108
118
}
109
119
110
120
/// Convert single token to bytes.
@@ -121,9 +131,9 @@ impl LlamaModel {
121
131
/// # Errors
122
132
///
123
133
/// See [`TokenToStringError`] for more information.
124
- pub fn tokens_to_str ( & self , tokens : & [ LlamaToken ] ) -> Result < String , TokenToStringError > {
134
+ pub fn tokens_to_str ( & self , tokens : & [ LlamaToken ] , special : Special ) -> Result < String , TokenToStringError > {
125
135
let mut builder = String :: with_capacity ( tokens. len ( ) * 4 ) ;
126
- for str in tokens. iter ( ) . copied ( ) . map ( |t| self . token_to_str ( t) ) {
136
+ for str in tokens. iter ( ) . copied ( ) . map ( |t| self . token_to_str ( t, special ) ) {
127
137
builder += & str?;
128
138
}
129
139
Ok ( builder)
@@ -236,6 +246,7 @@ impl LlamaModel {
236
246
& self ,
237
247
token : LlamaToken ,
238
248
buffer_size : usize ,
249
+ special : Special ,
239
250
) -> Result < String , TokenToStringError > {
240
251
let bytes = self . token_to_bytes_with_size ( token, buffer_size) ?;
241
252
Ok ( String :: from_utf8 ( bytes) ?)
@@ -264,6 +275,7 @@ impl LlamaModel {
264
275
return Ok ( String :: from ( "\n " ) . into_bytes ( ) ) ;
265
276
}
266
277
278
+ // TODO: decide how the token-type handling below should interact with the `special` argument; it currently runs regardless of the value passed.
267
279
match self . token_type ( token) {
268
280
LlamaTokenType :: Normal | LlamaTokenType :: UserDefined => { }
269
281
LlamaTokenType :: Control => {
@@ -279,12 +291,17 @@ impl LlamaModel {
279
291
}
280
292
}
281
293
294
+ let special = match special {
295
+ Special :: Tokenize => true ,
296
+ Special :: Plaintext => false ,
297
+ } ;
298
+
282
299
let string = CString :: new ( vec ! [ b'*' ; buffer_size] ) . expect ( "no null" ) ;
283
300
let len = string. as_bytes ( ) . len ( ) ;
284
301
let len = c_int:: try_from ( len) . expect ( "length fits into c_int" ) ;
285
302
let buf = string. into_raw ( ) ;
286
303
let size = unsafe {
287
- llama_cpp_sys_2:: llama_token_to_piece ( self . model . as_ptr ( ) , token. 0 , buf, len)
304
+ llama_cpp_sys_2:: llama_token_to_piece ( self . model . as_ptr ( ) , token. 0 , buf, len, special )
288
305
} ;
289
306
290
307
match size {
0 commit comments