 import torch
 import torch._dynamo.config
 import torch._inductor.config
-import torch.nn as nn
+import torch.distributed as dist

-from torchchat.model import Model, ModelArgs, ModelType
+from torchchat.distributed.utils import (
+    Color as color,
+    CUDATrackTime,
+    init_distributed,
+    GPUMemoryMonitor,
+)
+from torchchat.distributed.logging_utils import SingletonLogger

+from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model

+
 from torchtune.models.convert_weights import meta_to_tune

 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
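Note on the import changes: the `torch.nn` import is swapped for `torch.distributed as dist`, and the distributed helpers (`init_distributed`, `CUDATrackTime`, `GPUMemoryMonitor`, `SingletonLogger`) come from torchchat's own distributed utilities. The sketch below is not the real body of that helper, only a typical implementation of what an `init_distributed()` entry point does when the job is launched with torchrun (which exports RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, LOCAL_RANK):

import os
import torch
import torch.distributed as dist

def init_distributed_sketch(backend: str = "nccl") -> None:
    # Join the process group via the env:// rendezvous that torchrun sets up.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend, init_method="env://")
    # Pin this process to one GPU; the branch added later in this diff uses the
    # equivalent rank % torch.cuda.device_count() mapping when it picks its device.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))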
@@ -56,6 +64,7 @@ class BuilderArgs:
     pp: int = 1
     tp: int = 1
     chpt_from: str = "hf"
+    distribution_path: Optional[str] = None
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -107,6 +116,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":

         checkpoint_path = args.checkpoint_path
         params_table = args.params_table
+        distribution_path = None
         if args.model:  # Using a named, well-known model
             model_config = resolve_model_config(args.model)

@@ -121,6 +131,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
                 model_config.transformer_params_key or model_config.name.split("/")[-1]
             )

+            distribution_path = model_config.distribution_path
+
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
@@ -186,6 +198,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             pp=pp,
             tp=tp,
             chpt_from=chpt_from,
+            distribution_path=distribution_path,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
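That is the whole of the new plumbing in from_args: the field defaults to None and is only filled in for named models, whose resolved config carries the location of a distributed checkpoint. A hypothetical trace of the flow (attribute names are taken from this diff; the kwargs above are presumably part of the return cls(...) call; the example value is invented):

# distribution_path = None                                 # default
# if args.model:                                           # named, well-known model
#     model_config = resolve_model_config(args.model)
#     ...
#     distribution_path = model_config.distribution_path   # e.g. a HF repo id (assumption)
# ...
# return cls(..., pp=pp, tp=tp, distribution_path=distribution_path, ...)
#
# A model supplied only via --checkpoint-path keeps distribution_path=None, so as
# written the distributed branch added below has no weights source for it.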
@@ -601,6 +614,100 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.distributed:
+        pp_degree = builder_args.pp
+        tp_degree = builder_args.tp
+
+        init_distributed()
+        rank = dist.get_rank()
+        torch.cuda.set_device(rank % torch.cuda.device_count())
+
+        logger = SingletonLogger.get_logger()
+
+        gpu_memory_monitor = GPUMemoryMonitor("cuda")
+        logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
+
+        # Model-level config
+        if builder_args.params_table:
+            model_config = ModelArgs.from_table(builder_args.params_table)
+        else:
+            raise NotImplementedError()
+        # Transformer-level config
+        config = TransformerArgs.from_params(model_config.transformer_args["text"])
+        logger.info(f"Transformer Config: {config}")
+
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import (
+            load_model_weights,
+        )
+
+        # Validate pipeline degree
+        assert config.n_layers % pp_degree == 0
+
+        # Create device mesh
+        device_mesh = dist.init_device_mesh(
+            "cuda",
+            (pp_degree, tp_degree),
+            mesh_dim_names=("pp", "tp")
+        )
+        tp_mesh = device_mesh["tp"]
+        pp_mesh = device_mesh["pp"]
+        logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
+
+        pp_rank = pp_mesh.get_local_rank()
+        logger.info(f"{pp_degree=}, {tp_degree=}")
+
+        # Assuming same number of GPUs per node
+        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+        # Fill in PP configs
+        config.stage_idx = pp_rank
+        config.n_stages = pp_degree
+
+        with torch.device("meta"):
+            # TODO: we should create model instead of Transformer
+            model = Transformer(config)
+
+        # Distribute model on TP mesh
+        # (Surprisingly, this works even though model is on meta device and mesh is of
+        # cuda devices)
+        model.distribute(tp_mesh)
+        if rank == 0:
+            logger.info(f"Model: {model}")
+
+        # Load weights
+        logger.info(f"Loading weights for {pp_rank=} on {device=}")
+        with CUDATrackTime() as timer:
+            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+
+        logger.info(
+            f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+        )
+
+        # Setup KV caches (after model distribution)
+        # The number of cache lanes is the same as the maximum number of
+        # micro-batches that can be "in flight" in parallel -- imagine each
+        # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+        # When decoding is done for certain micro-batches, we can reuse the KV cache
+        # lanes.
+        # TODO: bump up the lane count
+        pipeline_lanes = 1
+        seqlen_prefill = 1024
+        with device:
+            model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
+
+        # info on stage size and params
+        # stage_size = get_module_size(model)
+        # stage_size_formatted = bytes_to_readable(stage_size)
+        # stage_num_params = get_num_params(model)
+        # logger.info(
+        #     f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
+        # )
+        model.eval()
+
+        model.text_transformer_args = None
+        model.config.model_type = model_config.model_type
+        model.device_mesh = device_mesh
     else:
         with measure_time("Time to load model: {time:.02f} seconds"):
             model = _load_model(builder_args)
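Taken together, the new elif branch follows a standard distributed-init recipe: join the process group, carve the world into a pp x tp DeviceMesh, construct the model on the meta device so no real memory is allocated, distribute it over the tp sub-mesh, then materialize and load weights on the local GPU before setting up KV caches. Below is a condensed, self-contained sketch of that recipe using only public PyTorch APIs; the nn.Linear and the zero-fill are stand-ins for torchchat's Transformer(config) and load_model_weights(...), and world_size is assumed to equal pp_degree * tp_degree under torchrun:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

def build_distributed_sketch(pp_degree: int, tp_degree: int) -> nn.Module:
    if not dist.is_initialized():
        dist.init_process_group("nccl")          # env:// rendezvous from torchrun
    rank = dist.get_rank()
    local_gpu = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_gpu)
    device = torch.device("cuda", local_gpu)

    # 2-D mesh, laid out row-major over global ranks: with pp=2, tp=2 the tp
    # groups are {0,1} and {2,3}, and the pp groups are {0,2} and {1,3}.
    mesh = init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
    pp_rank = mesh["pp"].get_local_rank()        # maps to config.stage_idx in the diff

    # Meta-device construction: parameters get shape/dtype but no storage, so a
    # model far larger than one GPU can still be instantiated here.
    with torch.device("meta"):
        model = nn.Linear(4096, 4096, bias=False)  # stand-in for Transformer(config)

    # Materialize empty storage on the local GPU, then fill it; the real code
    # loads the pp_rank-th slice of the checkpoint instead of zeroing.
    model = model.to_empty(device=device)
    with torch.no_grad():
        for p in model.parameters():
            p.zero_()                            # stand-in for load_model_weights(...)
    return model.eval()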