4343 "build_pipeline_schedule" ,
4444 "generate_llm_fqn_per_model_part" ,
4545 "pipeline_module_split" ,
46+ "split_module" ,
47+ "get_pp_rank_to_stage_indices_mapping" ,
4648]
4749
4850
@@ -64,10 +66,10 @@ def pipeline_llm(
     pp_mesh = parallel_dims.get_mesh("pp")

     num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
-        parallel_dims, job_config, model_args
+        parallel_dims, parallelism, model_config
     )

-    module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
+    module_names_per_stage = parallelism.module_fqns_per_model_part
     if module_names_per_stage is None:
         module_names_per_stage = generate_llm_fqn_per_model_part(
             num_virtual_stages, num_layers, input_weight, output_weight
@@ -78,7 +80,7 @@ def pipeline_llm(
     stages, model_parts = pipeline_module_split(
         model,
         pp_mesh,
-        job_config.parallelism.pipeline_parallel_schedule,
+        parallelism.pipeline_parallel_schedule,
         device,
         module_names_per_stage,
     )
@@ -88,13 +90,27 @@ def pipeline_llm(
     # optimizer, and checkpointing
     for i, m in enumerate(model_parts):
         # apply SPMD-style PT-D techniques
-        m = parallelize_fn(m, parallel_dims, job_config)
+        m = parallelize_fn(
+            m,
+            parallel_dims=parallel_dims,
+            training=training,
+            model_converters=model_converters,
+            parallelism=parallelism,
+            compile_config=compile_config,
+            ac_config=ac_config,
+            dump_folder=dump_folder,
+        )
         model_parts[i] = m
         # NOTE: this is to update the model in the stage
         # in case the model is modified e.g. by torch.compile
         stages[i].submod = m

-    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
+    pp_schedule = build_pipeline_schedule(
+        parallelism=parallelism,
+        local_batch_size=training.local_batch_size,
+        stages=stages,
+        loss_fn=loss_fn,
+    )

     # This is used in the train loop to determine whether to pass in the input_ids and labels
     has_first_stage = False
@@ -110,16 +126,16 @@ def pipeline_llm(

 def get_pipeline_metadata(
     parallel_dims: ParallelDims,
-    job_config: JobConfig,
-    model_args: BaseModelArgs,
+    parallelism: ParallelismConfig,
+    model_config: BaseModel.Config,
 ) -> tuple[int, int, int, int]:
     """
     Determine the number of virtual stages and the number of layers in the model.

     Args:
         parallel_dims (ParallelDims): Parallel dimensions.
-        job_config (JobConfig): Job configuration.
-        model_args (BaseModelArgs): Model arguments.
+        parallelism (ParallelismConfig): Parallelism configuration.
+        model_config (BaseModel.Config): Model configuration.

     Returns:
         tuple: A tuple containing the number of virtual stages, the number of layers in the model,
@@ -194,6 +210,8 @@ def build_pipeline_schedule(
     local_batch_size: int,
     stages: list[PipelineStage],
     loss_fn: Callable,
+    backward_requires_autograd: bool = True,
+    scale_grads: bool = False,
 ) -> _PipelineSchedule:
     """Builds a pipeline schedule for the given job configuration and stages.

@@ -242,7 +260,8 @@ def build_pipeline_schedule(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
-        scale_grads=False,
+        backward_requires_autograd=backward_requires_autograd,
+        scale_grads=scale_grads,
     )
     logger.info(
         f"Using pipeline schedule {parallelism.pipeline_parallel_schedule} "
@@ -403,7 +422,9 @@ def split_module(
     modules_to_keep = set(module_names)
     for module_name, module_value in model.named_children():
         # Handle layer-like structures (e.g., "layers.0", "layers.1")
-        if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
+        if isinstance(
+            module_value, (nn.ModuleDict, nn.ModuleList, ModuleDict, ModuleList)
+        ):
             layers_to_keep = {
                 name.split(".", 1)[1]
                 for name in modules_to_keep
@@ -419,7 +440,7 @@ def split_module(
                     indices_to_keep = {
                         int(idx) for idx in layers_to_keep if idx.isdigit()
                     }
-                    new_layers = nn.ModuleList(
+                    new_layers = ModuleList(
                         [
                             layer
                             for i, layer in enumerate(module_value)
@@ -429,10 +450,10 @@ def split_module(
                     setattr(model, module_name, new_layers)
             else:
                 # No layers from this structure needed, set to empty structure
-                if isinstance(module_value, nn.ModuleDict):
-                    setattr(model, module_name, nn.ModuleDict())
-                elif isinstance(module_value, nn.ModuleList):
-                    setattr(model, module_name, nn.ModuleList())
+                if isinstance(module_value, (nn.ModuleDict, ModuleDict)):
+                    setattr(model, module_name, ModuleDict())
+                elif isinstance(module_value, (nn.ModuleList, ModuleList)):
+                    setattr(model, module_name, ModuleList())
         # Handle simple module attributes (e.g., "linear", "norm")
         elif module_name not in modules_to_keep:
             # Replace with None
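
For context on the split_module hunks above: each pipeline stage is built from a list of module FQNs (e.g. "tok_embeddings", "layers.0"), and every child module not named in that list is emptied or replaced with None. The snippet below is a minimal, self-contained sketch of that pruning idea; prune_to_stage and ToyLLM are illustrative assumptions (not this repo's API), and it uses plain nn.ModuleList rather than the custom ModuleList/ModuleDict wrappers the diff switches to.

# Toy sketch (assumed names, not this repo's API) of per-stage module pruning.
import torch.nn as nn


def prune_to_stage(model: nn.Module, module_names: list[str]) -> nn.Module:
    keep = set(module_names)
    for name, child in model.named_children():
        if isinstance(child, nn.ModuleList):
            # "layers.0" -> keep index 0 of the "layers" container
            indices = {
                int(fqn.split(".", 1)[1])
                for fqn in keep
                if fqn.startswith(f"{name}.") and fqn.split(".", 1)[1].isdigit()
            }
            setattr(
                model,
                name,
                nn.ModuleList(m for i, m in enumerate(child) if i in indices),
            )
        elif name not in keep:
            # Plain attributes (embedding, norm, output head) are dropped entirely
            setattr(model, name, None)
    return model


class ToyLLM(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 8, vocab: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)


# Stage 0 keeps the embedding plus the first two layers; stage 1 keeps the rest.
stage0 = prune_to_stage(ToyLLM(), ["tok_embeddings", "layers.0", "layers.1"])
stage1 = prune_to_stage(ToyLLM(), ["layers.2", "layers.3", "norm", "output"])
assert len(stage0.layers) == 2 and stage0.output is None
assert stage1.tok_embeddings is None and len(stage1.layers) == 2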