 import torch._inductor.config
 import torch.distributed as dist

-from torchchat.distributed.utils import (
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger

 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.quantize import quantize_model


-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -194,15 +194,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} )
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} 
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -321,7 +325,17 @@ def validate_model(
         if model is None:
             return

-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece, self.is_llama_3_2_mm]) != 1:
+        if (
+            sum(
+                [
+                    self.is_tiktoken,
+                    self.is_hf_tokenizer,
+                    self.is_sentencepiece,
+                    self.is_llama_3_2_mm,
+                ]
+            )
+            != 1
+        ):
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path} )

         is_tiktoken = self.is_tiktoken
@@ -333,10 +347,10 @@ def validate_model(
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_other_tokenizer) or
-            (is_llama_3_2_mm and not use_other_tokenizer)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_other_tokenizer)
+            or (is_llama_3_2_mm and not use_other_tokenizer)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -534,6 +548,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -591,6 +606,7 @@ def _initialize_model(

             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing

             model.forward = torch._export.aot_load(
@@ -628,6 +644,7 @@ def do_nothing(max_batch_size, max_seq_length):

             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing

             model.forward = aoti_compiled_model
@@ -702,7 +719,9 @@ def do_nothing(max_batch_size, max_seq_length):
         logger = SingletonLogger.get_logger()

         gpu_memory_monitor = GPUMemoryMonitor("cuda")
-        logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset} )
+        logger.info(
+            f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset} 
+        )

         # Model-level config
         if builder_args.params_table:
@@ -713,20 +732,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
         logger.info(f"Transformer Config: {config} )

-        #TODO: Move into head of file after solving circular import
-        from torchchat.distributed.checkpoint_utils import (
-            load_model_weights,
-            )
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import load_model_weights

         # Validate pipeline degree
         assert config.n_layers % pp_degree == 0

         # Create device mesh
         device_mesh = dist.init_device_mesh(
-            "cuda",
-            (pp_degree, tp_degree),
-            mesh_dim_names=("pp", "tp")
-            )
+            "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+        )
         tp_mesh = device_mesh["tp"]
         pp_mesh = device_mesh["pp"]
         logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=} {pp_mesh=} )
@@ -755,7 +770,13 @@ def do_nothing(max_batch_size, max_seq_length):
         # Load weights
         logger.info(f"Loading weights for {pp_rank=} {device=} )
         with CUDATrackTime() as timer:
-            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+            load_model_weights(
+                model,
+                builder_args.distribution_path,
+                device,
+                config,
+                builder_args.chpt_from,
+            )

         logger.info(
             f"{color.green}{timer.get_time()} {timer.unit} {rank}{color.reset} 
@@ -769,7 +790,7 @@ def do_nothing(max_batch_size, max_seq_length):
         # lanes.
         # TODO: bump up the lane count
         pipeline_lanes = 1
-        seqlen_prefill=1024
+        seqlen_prefill = 1024
         with device:
             model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
