@@ -190,6 +190,7 @@ class TokenizerArgs:
190190 tokenizer_path : Optional [Union [Path , str ]] = None
191191 is_sentencepiece : bool = False
192192 is_tiktoken : bool = False
193+ is_tokenizers : bool = False
193194 t : Optional [Any ] = None
194195
195196 def __post_init__ (self ):
@@ -199,6 +200,7 @@ def __post_init__(self):
199200 self .t = TiktokenTokenizer (model_path = str (self .tokenizer_path ))
200201 self .is_tiktoken = True
201202 self .is_sentencepiece = False
203+ self .is_tokenizers = False
202204 return
203205 except :
204206 pass
@@ -209,12 +211,25 @@ def __post_init__(self):
209211 self .t = SentencePieceProcessor (model_file = str (self .tokenizer_path ))
210212 self .is_tiktoken = False
211213 self .is_sentencepiece = True
214+ self .is_tokenizers = False
215+ return
216+ except :
217+ pass
218+
219+ try :
220+ from tokenizer .tokenizers import TokenizersTokenizer
221+
222+ self .t = TokenizersTokenizer (str (self .tokenizer_path ))
223+ self .is_tiktoken = False
224+ self .is_sentencepiece = False
225+ self .is_tokenizers = True
212226 return
213227 except :
214228 pass
215229
216230 self .is_tiktoken = False
217231 self .is_sentencepiece = False
232+ self .is_tokenizers = False
218233 self .t = None
219234 return
220235
@@ -226,16 +241,27 @@ def validate_model(
226241 if model is None :
227242 return
228243
229- if self .is_tiktoken == self .is_sentencepiece :
244+ if len ( list ( filter ( lambda x : x , [ self .is_tiktoken , self . is_tokenizers , self .is_sentencepiece ]))) != 1 :
230245 raise RuntimeError (f"no tokenizer was found at { self .tokenizer_path } " )
231246
232247 is_tiktoken = self .is_tiktoken
233248 is_sentencepiece = self .is_sentencepiece
249+ is_tokenizers = self .is_tokenizers
234250 use_tiktoken = model .config .use_tiktoken
251+ use_tokenizers = model .config .use_tokenizers
252+ use_sentencepiece = not (use_tiktoken or use_tokenizers )
235253
236- if not (is_tiktoken == use_tiktoken ) or not (is_sentencepiece != use_tiktoken ):
254+ if (
255+ (is_tiktoken and not use_tiktoken ) or
256+ (is_tokenizers and not use_tokenizers ) or
257+ (is_sentencepiece and not use_sentencepiece )
258+ ):
237259 raise RuntimeError (
238- f"model-specified tokenizer ({ tokenizer_setting_to_name (use_tiktoken )} ) does not match provided tokenizer ({ tokenizer_setting_to_name (is_tiktoken )} ) for { model_description } "
260+ "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}" .format (
261+ tokenizer_setting_to_name (use_tiktoken , use_tokenizers ),
262+ tokenizer_setting_to_name (is_tiktoken , is_tokenizers ),
263+ model_description ,
264+ )
239265 )
240266
241267 return
@@ -591,5 +617,9 @@ def _initialize_model(
591617 return model
592618
593619
594- def tokenizer_setting_to_name (tiktoken : bool = False ) -> str :
595- return "TikToken" if tiktoken else "SentencePiece"
def tokenizer_setting_to_name(tiktoken: bool = False, tokenizers: bool = False) -> str:
    """Map tokenizer selection flags to a human-readable tokenizer name.

    Defaults restore backward compatibility with the previous signature
    (``tokenizer_setting_to_name(tiktoken: bool = False)``), so legacy
    zero- or one-argument call sites keep working.

    Args:
        tiktoken: True when the TikToken tokenizer is selected. Takes
            precedence if both flags are (erroneously) set.
        tokenizers: True when the "tokenizers"-backed tokenizer
            (TokenizersTokenizer) is selected.

    Returns:
        "TikToken", "Tokenizers", or "SentencePiece" (the fallback when
        neither flag is set).
    """
    if tiktoken:
        return "TikToken"
    if tokenizers:
        return "Tokenizers"
    return "SentencePiece"