@@ -4,9 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import argparse
 import os
 import sys
-import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
@@ -21,12 +21,7 @@
 except ImportError:
     pass

-from distributed import (
-    init_distributed,
-    launch_distributed,
-    ParallelDims,
-    parallelize_llama,
-)
+from distributed import launch_distributed, ParallelDims, parallelize_llama

 from torch.distributed.device_mesh import DeviceMesh

@@ -101,7 +96,7 @@ def __post_init__(self):
             self.prefill_possible = True

     @classmethod
-    def from_args(cls, args):  # -> BuilderArgs:
+    def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         # Handle disabled checkpoint_dir option
         checkpoint_dir = None
         if hasattr(args, "checkpoint_dir"):
@@ -183,7 +178,7 @@ def from_args(cls, args):  # -> BuilderArgs:
         )

     @classmethod
-    def from_speculative_args(cls, args):  # -> BuilderArgs:
+    def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args = BuilderArgs.from_args(args)
         # let's limit multi-checkpoint to checker
         speculative_builder_args.checkpoint_dir = None
@@ -229,7 +224,7 @@ def __post_init__(self):

     def validate_model(
         self,
-        model: Model,
+        model: Optional[Model],
         model_description: str = "model",
     ) -> None:
         if model is None:
@@ -250,10 +245,21 @@ def validate_model(
         return

     @classmethod
-    def from_args(cls, args):  # -> TokenizerArgs:
-        is_sentencepiece = False
-        is_tiktoken = False
-
+    def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
+        """
+        Create a TokenizerArgs object from command line arguments.
+        Specifically, `tokenizer_path` is resolved with this precedence:
+          * Explicitly provided tokenizer_path
+          * Resolved via the model_config identified by args.model
+          * Found as tokenizer.model in the directory of args.checkpoint_path
+          * Found as tokenizer.model in the directory of args.checkpoint_dir
+
+        Args:
+            args (argparse.Namespace): The command line arguments.
+
+        Returns:
+            TokenizerArgs: A TokenizerArgs object.
+        """
         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
         elif args.model:  # Using a named, well-known model
@@ -263,7 +269,6 @@ def from_args(cls, args):  # -> TokenizerArgs:
                 / model_config.name
                 / model_config.tokenizer_file
             )
-
         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
         elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
@@ -276,12 +281,7 @@ def from_args(cls, args):  # -> TokenizerArgs:
                 f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
             )

-        return cls(
-            tokenizer_path=tokenizer_path,
-            is_sentencepiece=is_sentencepiece,
-            is_tiktoken=is_tiktoken,
-            t=None,
-        )
+        return cls(tokenizer_path=tokenizer_path)


 def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
@@ -299,7 +299,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):


 # TODO: remove these once ET supports _weight_int4pack_mm
-def _set_gguf_kwargs(builder_args, is_et, context: str):
+def _set_gguf_kwargs(builder_args: BuilderArgs, is_et: bool, context: str) -> None:
     assert context in ["export", "generate"]
     assert builder_args.gguf_kwargs is None

@@ -312,11 +312,11 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
         builder_args.gguf_kwargs["load_as_quantized"] = False


-def _unset_gguf_kwargs(builder_args):
+def _unset_gguf_kwargs(builder_args: BuilderArgs) -> None:
     builder_args.gguf_kwargs = None


-def _init_model_on_meta_device(builder_args):
+def _init_model_on_meta_device(builder_args: BuilderArgs) -> Model:
     with torch.device("meta"):
         if builder_args.params_path:
             return Model.from_params(builder_args.params_path)
@@ -326,7 +326,7 @@ def _init_model_on_meta_device(builder_args):
             return Model.from_name(builder_args.checkpoint_path.parent.name)


-def _load_model_gguf(builder_args, only_config=False):
+def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     assert builder_args.gguf_path
     if builder_args.gguf_kwargs is None:
         kwargs = {}
@@ -336,10 +336,10 @@ def _load_model_gguf(builder_args, only_config=False):
     return model


-def _load_model_default(builder_args, only_config=False):
+def _load_model_default(builder_args: BuilderArgs) -> Model:
     assert not builder_args.gguf_path

-    model = _init_model_on_meta_device(builder_args)
+    model: Model = _init_model_on_meta_device(builder_args)

     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
@@ -459,7 +459,7 @@ def _maybe_parellelize_model(
     return load_checkpoints_to_model(model, builder_args, world_mesh)


-def _load_model(builder_args, only_config=False):
+def _load_model(builder_args: BuilderArgs) -> Model:
     world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
@@ -474,12 +474,12 @@ def _load_model(builder_args, only_config=False):


 def _initialize_model(
-    builder_args,
+    builder_args: BuilderArgs,
     quantize,
     tokenizer=None,
     max_seq_length=None,
     support_tensor_subclass: bool = True,
-):
+) -> Model:
     print("Loading model...")

     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
@@ -505,7 +505,7 @@ def _initialize_model(
         # ), "quantize not valid for exported DSO model. Specify quantization during export."

         with measure_time("Time to load model: {time:.02f} seconds"):
-            model = _load_model(builder_args, only_config=True)
+            model = _load_model(builder_args)
             device_sync(device=builder_args.device)

         try:
@@ -532,7 +532,7 @@ def _initialize_model(
         # ), "quantize not valid for exported PTE model. Specify quantization during export."

         with measure_time("Time to load model: {time:.02f} seconds"):
-            model = _load_model(builder_args, only_config=True)
+            model = _load_model(builder_args)
             device_sync(device=builder_args.device)

         try:
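
For readers skimming the new `TokenizerArgs.from_args` docstring, here is a minimal standalone sketch of the same resolution order. It is not part of the patch: `resolve_tokenizer_path` is a hypothetical helper, and the `model_config` branch is elided since that lookup lives elsewhere in the codebase.

```python
import argparse
from pathlib import Path
from typing import Optional


def resolve_tokenizer_path(args: argparse.Namespace) -> Optional[Path]:
    # 1. An explicitly provided tokenizer_path always wins.
    if getattr(args, "tokenizer_path", None):
        return Path(args.tokenizer_path)
    # 2. A named, well-known model would resolve through its model_config
    #    (omitted here; see the model_config branch in the diff).
    # 3./4. Otherwise look for tokenizer.model next to the checkpoint file
    #    or inside the checkpoint directory.
    if getattr(args, "checkpoint_path", None):
        return Path(args.checkpoint_path).parent / "tokenizer.model"
    if getattr(args, "checkpoint_dir", None):
        return Path(args.checkpoint_dir) / "tokenizer.model"
    return None
```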