@@ -51,6 +51,15 @@ pub enum AddBos {
     Never,
 }
 
+/// How to determine if we should tokenize special tokens
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Special {
+    /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
+    Tokenize,
+    /// Treat special and/or control tokens as plaintext.
+    Plaintext,
+}
+
 unsafe impl Send for LlamaModel {}
 
 unsafe impl Sync for LlamaModel {}
@@ -71,10 +80,11 @@ impl LlamaModel {
     /// Get all tokens in the model.
     pub fn tokens(
         &self,
+        special: Special,
     ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
         (0..self.n_vocab())
             .map(LlamaToken::new)
-            .map(|llama_token| (llama_token, self.token_to_str(llama_token)))
+            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
     }
 
     /// Get the beginning of stream token.
@@ -103,27 +113,27 @@ impl LlamaModel {
     /// # Errors
     ///
     /// See [`TokenToStringError`] for more information.
-    pub fn token_to_str(&self, token: LlamaToken) -> Result<String, TokenToStringError> {
-        self.token_to_str_with_size(token, 32)
+    pub fn token_to_str(&self, token: LlamaToken, special: Special) -> Result<String, TokenToStringError> {
+        self.token_to_str_with_size(token, 32, special)
     }
 
     /// Convert single token to bytes.
     ///
     /// # Errors
     ///
     /// See [`TokenToStringError`] for more information.
-    pub fn token_to_bytes(&self, token: LlamaToken) -> Result<Vec<u8>, TokenToStringError> {
-        self.token_to_bytes_with_size(token, 32)
+    pub fn token_to_bytes(&self, token: LlamaToken, special: Special) -> Result<Vec<u8>, TokenToStringError> {
+        self.token_to_bytes_with_size(token, 32, special)
     }
 
     /// Convert a vector of tokens to a single string.
     ///
     /// # Errors
     ///
     /// See [`TokenToStringError`] for more information.
-    pub fn tokens_to_str(&self, tokens: &[LlamaToken]) -> Result<String, TokenToStringError> {
+    pub fn tokens_to_str(&self, tokens: &[LlamaToken], special: Special) -> Result<String, TokenToStringError> {
         let mut builder = String::with_capacity(tokens.len() * 4);
-        for str in tokens.iter().copied().map(|t| self.token_to_str(t)) {
+        for str in tokens.iter().copied().map(|t| self.token_to_str(t, special)) {
             builder += &str?;
         }
         Ok(builder)
@@ -236,8 +246,9 @@ impl LlamaModel {
         &self,
         token: LlamaToken,
         buffer_size: usize,
+        special: Special,
     ) -> Result<String, TokenToStringError> {
-        let bytes = self.token_to_bytes_with_size(token, buffer_size)?;
+        let bytes = self.token_to_bytes_with_size(token, buffer_size, special)?;
         Ok(String::from_utf8(bytes)?)
     }
 
@@ -259,11 +270,13 @@ impl LlamaModel {
         &self,
         token: LlamaToken,
         buffer_size: usize,
+        special: Special,
     ) -> Result<Vec<u8>, TokenToStringError> {
         if token == self.token_nl() {
             return Ok(String::from("\n").into_bytes());
         }
 
+        // unsure what to do with this in the face of the 'special' arg
         match self.token_type(token) {
             LlamaTokenType::Normal | LlamaTokenType::UserDefined => {}
             LlamaTokenType::Control => {
@@ -279,12 +292,17 @@ impl LlamaModel {
             }
         }
 
+        let special = match special {
+            Special::Tokenize => true,
+            Special::Plaintext => false,
+        };
+
         let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
         let len = string.as_bytes().len();
         let len = c_int::try_from(len).expect("length fits into c_int");
         let buf = string.into_raw();
         let size = unsafe {
-            llama_cpp_sys_2::llama_token_to_piece(self.model.as_ptr(), token.0, buf, len)
+            llama_cpp_sys_2::llama_token_to_piece(self.model.as_ptr(), token.0, buf, len, special)
         };
 
         match size {
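
For reference, a minimal sketch of how calling code might adapt to the new `special` argument. It assumes a loaded `LlamaModel` named `model` and an existing slice of `LlamaToken`s; the import paths and the `render` helper are illustrative assumptions, not confirmed by this diff.

use llama_cpp_2::model::{LlamaModel, Special};
use llama_cpp_2::token::LlamaToken;

// Hypothetical helper showing both variants of the new `Special` argument.
fn render(model: &LlamaModel, tokens: &[LlamaToken]) -> Result<String, Box<dyn std::error::Error>> {
    // Decode special/control tokens (e.g. BOS/EOS markers) into their text pieces...
    let with_specials = model.tokens_to_str(tokens, Special::Tokenize)?;
    // ...or keep treating them as plaintext, matching the old behaviour.
    let as_plaintext = model.tokens_to_str(tokens, Special::Plaintext)?;
    Ok(format!("tokenized: {with_specials}\nplaintext: {as_plaintext}"))
}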