 import torch._inductor.config
 import torch.distributed as dist

-from torchchat.distributed.utils import (
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger

 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
...
 from torchchat.utils.quantize import quantize_model


-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
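Note: the torchtune imports deleted above are not dropped; the hunks at _load_checkpoint and _load_model_default below re-add them as function-local imports, so importing this module no longer requires torchtune unless a Tune checkpoint or Flamingo model is actually loaded. A minimal sketch of the deferred-import pattern (the helper name is hypothetical; the calls mirror the hunks below):

    import torch

    def load_tune_checkpoint(path):
        # Local import: the torchtune dependency is paid only when this
        # branch runs; later calls are cheap via the sys.modules cache.
        from torchtune.models.convert_weights import meta_to_tune

        meta_checkpoint = torch.load(str(path), mmap=True, weights_only=True)
        return meta_to_tune(meta_checkpoint)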
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
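The sdp_backend_dict above maps the CLI's attention-backend strings onto torch.nn.attention.SDPBackend enum values, with MATH as the fallback that works on every device. A minimal sketch of how such a backend value is typically applied (assumes a PyTorch version that ships torch.nn.attention.sdpa_kernel; this diff does not show torchchat's call site):

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 8, 128, 64)
    # Restrict scaled_dot_product_attention to one implementation; MATH is
    # portable, which is why the CPU path above falls back to it.
    with sdpa_kernel(SDPBackend.MATH):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)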
@@ -238,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args

+
 class TokenizerType(Enum):
     NONE = 0
     TIKTOKEN = 1
     SENTENCEPIECE = 2
     HF_TOKENIZER = 3

+
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
@@ -307,9 +305,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -417,6 +415,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:

 def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
+        from torchtune.models.convert_weights import meta_to_tune
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
             str(builder_args.checkpoint_path), mmap=True, weights_only=True
@@ -469,9 +468,15 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         checkpoint = checkpoint["model"]

     if model.config.model_type == ModelType.Flamingo:
+        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+        from torchtune.models.llama3_2_vision._convert_weights import (
+            llama3_vision_meta_to_tune,
+        )
+        from torchtune.training import set_default_dtype
         # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size in memory, with redundancies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
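The rewritten with statement uses parenthesized context managers, which is worth flagging in review: the syntax is only documented from Python 3.10 (CPython 3.9's parser happens to accept it). A self-contained equivalent:

    from contextlib import nullcontext

    # Parenthesized multi-item `with` (Python 3.10+): one manager per line,
    # trailing comma allowed, no backslash continuations needed.
    with (
        nullcontext("dtype ctx") as a,
        nullcontext("device ctx") as b,
    ):
        print(a, b)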
@@ -507,6 +512,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -564,6 +570,7 @@ def _initialize_model(

         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing

         model.forward = torch._export.aot_load(
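Stubbing setup_caches with do_nothing here reflects that the AOTInductor artifact loaded via torch._export.aot_load carries its own weights and cache buffers, so eager-mode cache allocation must be skipped. A condensed sketch of the pattern from the hunk above (artifact path hypothetical; model comes from the surrounding builder code):

    import torch

    def do_nothing(max_batch_size, max_seq_length):
        pass  # the compiled artifact manages its own KV caches

    model.setup_caches = do_nothing
    # aot_load returns a callable that runs the compiled graph on `device`.
    model.forward = torch._export.aot_load("/path/to/model.so", device="cuda")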
@@ -601,6 +608,7 @@ def do_nothing(max_batch_size, max_seq_length):

         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing

         model.forward = aoti_compiled_model
@@ -652,12 +660,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend(dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels and custom operators
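Note the asymmetry with the checkpoint loads earlier in the file: a snapshot is a fully pickled Model object, so it needs weights_only=False, whereas plain state dicts can use the restricted unpickler. A short sketch (paths hypothetical):

    import torch

    # State dict: restricted unpickler, safe by default, mmap-friendly.
    state_dict = torch.load("checkpoint.pth", mmap=True, weights_only=True)

    # Whole pickled object graph: needs the general unpickler, so only
    # load snapshots you produced yourself.
    model = torch.load("snapshot.pt", weights_only=False)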
@@ -675,7 +686,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()

     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}"
+    )

     # Model-level config
     if builder_args.params_table:
@@ -686,20 +699,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")

-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights

     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0

     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
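init_device_mesh lays the world out as a pp_degree x tp_degree grid, and indexing the mesh by dimension name returns the 1-D submesh containing the current rank. A runnable sketch (assumes launch under torchrun with world_size == pp_degree * tp_degree, e.g. 8 GPUs):

    import torch.distributed as dist

    pp_degree, tp_degree = 2, 4
    mesh = dist.init_device_mesh(
        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
    )
    tp_mesh = mesh["tp"]  # this rank's row: its 4 tensor-parallel peers
    pp_mesh = mesh["pp"]  # this rank's column: its 2 pipeline peers
    print(f"{tp_mesh.get_local_rank()=} {pp_mesh.get_local_rank()=}")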
@@ -728,7 +737,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )

     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -742,7 +757,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024 
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
