@@ -270,7 +270,9 @@ class TransformerArgs:
270270 norm_eps : float = 1e-5
271271 multiple_of : int = 256
272272 ffn_dim_multiplier : Optional [int ] = None
273+ # Select the desired tokenizer (tiktoken or HF tokenizers); if neither flag is set, defaults to SentencePiece.
273274 use_tiktoken : bool = False
275+ use_tokenizers : bool = False
274276 max_seq_length : int = 8192
275277 rope_scaling : Optional [Dict [str , Any ]] = None
276278 # For pipeline parallel
@@ -327,12 +329,14 @@ class ModelArgs:
327329 model_type : ModelType
328330 transformer_args : Dict [str , Dict [str , Any ]]
329331 use_tiktoken : bool
332+ use_tokenizers : bool
330333
331334 def __init__ (
332335 self ,
333336 transformer_args : Dict [str , Dict [str , Any ]],
334337 model_type : ModelType = ModelType .TextOnly ,
335338 use_tiktoken : bool = False ,
339+ use_tokenizers : bool = False ,
336340 ) -> None :
337341 self ._sanity_check (transformer_args , model_type )
338342
@@ -341,6 +345,7 @@ def __init__(
341345
342346 # Model-level attributes
343347 self .use_tiktoken = use_tiktoken
348+ self .use_tokenizers = use_tokenizers
344349
345350 def _sanity_check (
346351 self ,
@@ -367,7 +372,8 @@ def from_params(cls, params_path):
367372 }
368373
369374 use_tiktoken = loaded_params .get ("use_tiktoken" , False )
370- return cls (transformer_args , model_type , use_tiktoken )
375+ use_tokenizers = loaded_params .get ("use_tokenizers" , False )
376+ return cls (transformer_args , model_type , use_tiktoken , use_tokenizers )
371377
372378 @classmethod
373379 def from_table (cls , name : str ):
0 commit comments