@@ -272,7 +272,9 @@ class TransformerArgs:
272272    norm_eps : float  =  1e-5 
273273    multiple_of : int  =  256 
274274    ffn_dim_multiplier : Optional [int ] =  None 
275+     # Tokenizer selection flags; SentencePiece is used when neither flag is set 
275276    use_tiktoken : bool  =  False 
277+     use_tokenizers : bool  =  False 
276278    max_seq_length : int  =  8192 
277279    rope_scaling : Optional [Dict [str , Any ]] =  None 
278280    # For pipeline parallel 
@@ -329,12 +331,14 @@ class ModelArgs:
329331    model_type : ModelType 
330332    transformer_args : Dict [str , Dict [str , Any ]]
331333    use_tiktoken : bool 
334+     use_tokenizers : bool 
332335
333336    def  __init__ (
334337        self ,
335338        transformer_args : Dict [str , Dict [str , Any ]],
336339        model_type : ModelType  =  ModelType .TextOnly ,
337340        use_tiktoken : bool  =  False ,
341+         use_tokenizers : bool  =  False ,
338342    ) ->  None :
339343        self ._sanity_check (transformer_args , model_type )
340344
@@ -343,6 +347,7 @@ def __init__(
343347
344348        # Model-level attributes 
345349        self .use_tiktoken  =  use_tiktoken 
350+         self .use_tokenizers  =  use_tokenizers 
346351
347352    def  _sanity_check (
348353        self ,
@@ -369,7 +374,8 @@ def from_params(cls, params_path):
369374            }
370375
371376        use_tiktoken  =  loaded_params .get ("use_tiktoken" , False )
372-         return  cls (transformer_args , model_type , use_tiktoken )
377+         use_tokenizers  =  loaded_params .get ("use_tokenizers" , False )
378+         return  cls (transformer_args , model_type , use_tiktoken , use_tokenizers )
373379
374380    @classmethod  
375381    def  from_table (cls , name : str ):
0 commit comments