@@ -16,22 +16,13 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-from torchchat.distributed.logging_utils import SingletonLogger
-
-from torchchat.distributed.utils import (
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    GPUMemoryMonitor,
     init_distributed,
+    GPUMemoryMonitor,
 )
+from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -45,6 +36,15 @@
 from torchchat.utils.quantize import quantize_model
 
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
+
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -189,19 +189,15 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            "math": torch.nn.attention.SDPBackend.MATH,
-            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (
-            args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"
-        ):
-            print(
-                f"Warning: {args.attention_backend} is not supported on CPU. Using the math backend instead."
-            )
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+                                     or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using the math backend instead.")
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
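
For context on the hunk above: the dictionary resolves the `--attention-backend` string into a `torch.nn.attention.SDPBackend` value, and the CPU branch falls back to the math kernel. A minimal standalone sketch (illustrative only, not code from this commit; assumes PyTorch >= 2.3) of how such a backend value is typically applied around an attention call:

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    backend = SDPBackend.MATH  # e.g. the CPU fallback chosen above
    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
    with sdpa_kernel([backend]):  # restrict SDPA to the selected kernel(s)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
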
@@ -250,29 +246,13 @@ class TokenizerArgs:
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
     is_hf_tokenizer: bool = False
-    is_llama_3_2_mm: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
-        # special handling for llama-3.2-mm
-        if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
-            try:
-                from torchtune.models.llama3_2_vision import llama3_2_vision_transform
-
-                self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
-                self.is_llama_3_2_mm = True
-                self.is_tiktoken = False
-                self.is_sentencepiece = False
-                self.is_hf_tokenizer = False
-                return
-            except:
-                pass
-
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = True
             self.is_sentencepiece = False
             self.is_hf_tokenizer = False
@@ -284,7 +264,6 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = True
             self.is_hf_tokenizer = False
@@ -296,15 +275,13 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
-            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = False
             self.is_hf_tokenizer = True
             return
         except:
             pass
 
-        self.is_llama_3_2_mm = False
         self.is_tiktoken = False
         self.is_sentencepiece = False
         self.is_hf_tokenizer = False
@@ -319,32 +296,20 @@ def validate_model(
         if model is None:
             return
 
-        if (
-            sum(
-                [
-                    self.is_tiktoken,
-                    self.is_hf_tokenizer,
-                    self.is_sentencepiece,
-                    self.is_llama_3_2_mm,
-                ]
-            )
-            != 1
-        ):
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         is_hf_tokenizer = self.is_hf_tokenizer
-        is_llama_3_2_mm = self.is_llama_3_2_mm
-
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
-        use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
+        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
+
         if (
-            (is_tiktoken and not use_tiktoken)
-            or (is_hf_tokenizer and not use_hf_tokenizer)
-            or (is_sentencepiece and not use_other_tokenizer)
-            or (is_llama_3_2_mm and not use_other_tokenizer)
+            (is_tiktoken and not use_tiktoken) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
+            (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -542,7 +507,6 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
-
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -600,7 +564,6 @@ def _initialize_model(
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
-
             model.setup_caches = do_nothing
 
             model.forward = torch._export.aot_load(
@@ -638,7 +601,6 @@ def do_nothing(max_batch_size, max_seq_length):
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
-
             model.setup_caches = do_nothing
 
             model.forward = aoti_compiled_model
@@ -713,9 +675,7 @@ def do_nothing(max_batch_size, max_seq_length):
         logger = SingletonLogger.get_logger()
 
         gpu_memory_monitor = GPUMemoryMonitor("cuda")
-        logger.info(
-            f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}"
-        )
+        logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
 
         # Model-level config
         if builder_args.params_table:
@@ -726,16 +686,20 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
         logger.info(f"Transformer Config: {config}")
 
-        # TODO: Move into head of file after solving circular import
-        from torchchat.distributed.checkpoint_utils import load_model_weights
+        #TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import (
+            load_model_weights,
+            )
 
         # Validate pipeline degree
         assert config.n_layers % pp_degree == 0
 
         # Create device mesh
         device_mesh = dist.init_device_mesh(
-            "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
-        )
+            "cuda",
+            (pp_degree, tp_degree),
+            mesh_dim_names=("pp", "tp")
+            )
         tp_mesh = device_mesh["tp"]
         pp_mesh = device_mesh["pp"]
         logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=} {pp_mesh=}")
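
For context on the mesh setup above: init_device_mesh lays the pp_degree x tp_degree ranks out as a 2-D grid, and indexing by dimension name yields the 1-D submesh each rank belongs to. A minimal standalone sketch (illustrative only, not code from this commit; assumes a 2x2 layout launched with torchrun and an initialized process group):

    from torch.distributed.device_mesh import init_device_mesh

    pp_degree, tp_degree = 2, 2
    mesh = init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
    tp_mesh = mesh["tp"]                # ranks sharing one pipeline stage
    pp_mesh = mesh["pp"]                # ranks forming the pipeline across stages
    pp_rank = pp_mesh.get_local_rank()  # which pipeline stage this rank serves
    tp_rank = tp_mesh.get_local_rank()  # position inside the tensor-parallel group
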
@@ -764,13 +728,7 @@ def do_nothing(max_batch_size, max_seq_length):
         # Load weights
         logger.info(f"Loading weights for {pp_rank=} {device=}")
         with CUDATrackTime() as timer:
-            load_model_weights(
-                model,
-                builder_args.distribution_path,
-                device,
-                config,
-                builder_args.chpt_from,
-            )
+            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
 
         logger.info(
             f"{color.green}{timer.get_time()} {timer.unit} {rank}{color.reset}"
@@ -784,7 +742,7 @@ def do_nothing(max_batch_size, max_seq_length):
         # lanes.
         # TODO: bump up the lane count
         pipeline_lanes = 1
-        seqlen_prefill = 1024
+        seqlen_prefill=1024
         with device:
             model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
@@ -836,4 +794,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"