@@ -272,7 +272,9 @@ class TransformerArgs:
272272 norm_eps : float = 1e-5
273273 multiple_of : int = 256
274274 ffn_dim_multiplier : Optional [int ] = None
275+ # Select the desired tokenizer. Defaults to sentencepiece when neither flag below is set.
275276 use_tiktoken : bool = False
277+ use_tokenizers : bool = False
276278 max_seq_length : int = 8192
277279 rope_scaling : Optional [Dict [str , Any ]] = None
278280 # For pipeline parallel
@@ -329,12 +331,14 @@ class ModelArgs:
329331 model_type : ModelType
330332 transformer_args : Dict [str , Dict [str , Any ]]
331333 use_tiktoken : bool
334+ use_tokenizers : bool
332335
333336 def __init__ (
334337 self ,
335338 transformer_args : Dict [str , Dict [str , Any ]],
336339 model_type : ModelType = ModelType .TextOnly ,
337340 use_tiktoken : bool = False ,
341+ use_tokenizers : bool = False ,
338342 ) -> None :
339343 self ._sanity_check (transformer_args , model_type )
340344
@@ -343,6 +347,7 @@ def __init__(
343347
344348 # Model-level attributes
345349 self .use_tiktoken = use_tiktoken
350+ self .use_tokenizers = use_tokenizers
346351
347352 def _sanity_check (
348353 self ,
@@ -369,7 +374,8 @@ def from_params(cls, params_path):
369374 }
370375
371376 use_tiktoken = loaded_params .get ("use_tiktoken" , False )
372- return cls (transformer_args , model_type , use_tiktoken )
377+ use_tokenizers = loaded_params .get ("use_tokenizers" , False )
378+ return cls (transformer_args , model_type , use_tiktoken , use_tokenizers )
373379
374380 @classmethod
375381 def from_table (cls , name : str ):
0 commit comments