1616import  torch ._inductor .config 
1717import  torch .nn  as  nn 
1818
19- from  torch .distributed .device_mesh  import  DeviceMesh 
20- from  torch .distributed .elastic .multiprocessing .errors  import  record 
21- from  torch .distributed .elastic .utils .distributed  import  get_free_port 
22- 
23- from  torchchat .distributed  import  launch_distributed , ParallelDims , parallelize_llama 
24- 
2519from  torchchat .model  import  Model , ModelArgs , ModelType 
2620
2721from  torchchat .model_config .model_config  import  resolve_model_config 
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
464458    return  model 
465459
466460
def _maybe_init_distributed(
    builder_args: BuilderArgs,
) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
    """
    Initialize distributed related setups if the user specified
    using distributed inference. If not, this is a no-op.

    Args:
        builder_args (:class:`BuilderArgs`):
            Command args for model building.
    Returns:
        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
            - The first element is an optional DeviceMesh object,
            which describes the mesh topology of devices for the DTensor.
            - The second element is an optional ParallelDims object,
            which represents the parallel dimensions configuration.
    Raises:
        RuntimeError: If the distributed launch fails to produce a
            device mesh or parallel-dims configuration.
    """
    if not builder_args.use_distributed:
        return None, None
    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line

    world_mesh, parallel_dims = launch_distributed(dist_config)

    # Validate with an explicit raise rather than `assert`: asserts are
    # stripped under `python -O`, which would silently skip this check.
    if world_mesh is None or parallel_dims is None:
        raise RuntimeError(f"failed to launch distributed using {dist_config}")

    return world_mesh, parallel_dims
496- 
def _maybe_parallelize_model(
    model: nn.Module,
    builder_args: BuilderArgs,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
) -> nn.Module:
    """
    Apply model parallelism and load the distributed checkpoint when the
    user requested distributed inference; otherwise return the model
    unchanged.

    Args:
        model (:class:`nn.Module`):
            Module to be parallelized.
        builder_args (:class:`BuilderArgs`):
            Command args for model building.
        world_mesh (:class:`DeviceMesh`):
            Mesh topology of devices for the DTensor; ``None`` signals
            that distributed inference is disabled.
        parallel_dims (:class:`ParallelDims`):
            Parallel dimensions configuration.
    Returns:
        A :class:`nn.Module` that has been parallelized and had its
        checkpoint loaded when distributed inference is enabled, or the
        original ``model`` otherwise.
    """
    # No mesh means distributed inference was not requested — no-op.
    if world_mesh is None:
        return model
    assert parallel_dims is not None
    print("Applying model parallel to model ...")
    parallelize_llama(model, world_mesh, parallel_dims)
    return load_checkpoints_to_model(model, builder_args, world_mesh)
528- 
529461def  _load_model (builder_args : BuilderArgs ) ->  Model :
530-     # world_mesh, parallel_dims = _maybe_init_distributed(builder_args) 
531462    if  builder_args .gguf_path :
532463        model  =  _load_model_gguf (builder_args )
533-     # elif builder_args.use_distributed: 
534-     #    model = _init_model_on_meta_device(builder_args) 
535464    else :
536465        model  =  _load_model_default (builder_args )
537-     # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) 
538466
539467    if  builder_args .dso_path  or  builder_args .aoti_package_path :
540468        # AOTI-compoiled model will load its own weights. 
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
706634        return  "TikToken" 
707635    if  tokenizers :
708636        return  "Tokenizers" 
709-     return  "SentencePiece" 
637+     return  "SentencePiece" 
0 commit comments