2222 ScheduleDualPipeV ,
2323 ScheduleZBVZeroBubble ,
2424)
25-
2625from torchtitan .components .loss import LossFunction
2726from torchtitan .config import (
2827 ActivationCheckpointConfig ,
4039
4140__all__ = [
4241 "pipeline_llm" ,
42+ "get_pipeline_metadata" ,
4343 "build_pipeline_schedule" ,
4444 "generate_llm_fqn_per_model_part" ,
4545 "pipeline_module_split" ,
@@ -63,6 +63,68 @@ def pipeline_llm(
6363) -> tuple [_PipelineSchedule , list [nn .Module ], bool , bool ]:
6464 pp_mesh = parallel_dims .get_mesh ("pp" )
6565
66+ num_virtual_stages , num_layers , input_weight , output_weight = get_pipeline_metadata (
67+ parallel_dims , job_config , model_args
68+ )
69+
70+ module_names_per_stage = job_config .parallelism .module_fqns_per_model_part
71+ if module_names_per_stage is None :
72+ module_names_per_stage = generate_llm_fqn_per_model_part (
73+ num_virtual_stages , num_layers , input_weight , output_weight
74+ )
75+ for i , stage_ms in enumerate (module_names_per_stage ):
76+ logger .debug (f"Stage { i } : { stage_ms } " )
77+
78+ stages , model_parts = pipeline_module_split (
79+ model ,
80+ pp_mesh ,
81+ job_config .parallelism .pipeline_parallel_schedule ,
82+ device ,
83+ module_names_per_stage ,
84+ )
85+
86+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
87+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
88+ # optimizer, and checkpointing
89+ for i , m in enumerate (model_parts ):
90+ # apply SPMD-style PT-D techniques
91+ m = parallelize_fn (m , parallel_dims , job_config )
92+ model_parts [i ] = m
93+ # NOTE: this is to update the model in the stage
94+ # in case the model is modified e.g. by torch.compile
95+ stages [i ].submod = m
96+
97+ pp_schedule = build_pipeline_schedule (job_config , stages , loss_fn )
98+
99+ # This is used in the train loop to determine whether to pass in the input_ids and labels
100+ has_first_stage = False
101+ has_last_stage = False
102+ for stage in stages :
103+ if stage .is_first :
104+ has_first_stage = True
105+ if stage .is_last :
106+ has_last_stage = True
107+
108+ return pp_schedule , model_parts , has_first_stage , has_last_stage
109+
110+
111+ def get_pipeline_metadata (
112+ parallel_dims : ParallelDims ,
113+ job_config : JobConfig ,
114+ model_args : BaseModelArgs ,
115+ ) -> tuple [int , int , int , int ]:
116+ """
117+ Determine the number of virtual stages and the number of layers in the model.
118+
119+ Args:
120+ parallel_dims (ParallelDims): Parallel dimensions.
121+ job_config (JobConfig): Job configuration.
122+ model_args (BaseModelArgs): Model arguments.
123+
124+ Returns:
125+ tuple: A tuple containing the number of virtual stages, the number of layers in the model,
126+ the input weight, and the output weight.
127+ """
66128 # Determine the number of virtual stages based on schedule type
67129 schedule_class = get_schedule_class (parallelism .pipeline_parallel_schedule )
68130 is_single_stage_schedule = issubclass (schedule_class , PipelineScheduleSingle )
@@ -123,60 +185,7 @@ def pipeline_llm(
123185 # For single-stage schedules, default is 1 virtual stage per rank
124186 stages_per_rank = 1 if is_single_stage_schedule else 2
125187 num_virtual_stages = parallel_dims .pp * stages_per_rank
126-
127- module_names_per_stage = parallelism .module_fqns_per_model_part
128- if module_names_per_stage is None :
129- module_names_per_stage = generate_llm_fqn_per_model_part (
130- num_virtual_stages , num_layers , input_weight , output_weight
131- )
132- for i , stage_ms in enumerate (module_names_per_stage ):
133- logger .debug (f"Stage { i } : { stage_ms } " )
134-
135- stages , model_parts = pipeline_module_split (
136- model ,
137- pp_mesh ,
138- parallelism .pipeline_parallel_schedule ,
139- device ,
140- module_names_per_stage ,
141- )
142-
143- # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
144- # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
145- # optimizer, and checkpointing
146- for i , m in enumerate (model_parts ):
147- # apply SPMD-style PT-D techniques
148- m = parallelize_fn (
149- m ,
150- parallel_dims = parallel_dims ,
151- training = training ,
152- model_converters = model_converters ,
153- parallelism = parallelism ,
154- compile_config = compile_config ,
155- ac_config = ac_config ,
156- dump_folder = dump_folder ,
157- )
158- model_parts [i ] = m
159- # NOTE: this is to update the model in the stage
160- # in case the model is modified e.g. by torch.compile
161- stages [i ].submod = m
162-
163- pp_schedule = build_pipeline_schedule (
164- parallelism = parallelism ,
165- local_batch_size = training .local_batch_size ,
166- stages = stages ,
167- loss_fn = loss_fn ,
168- )
169-
170- # This is used in the train loop to determine whether to pass in the input_ids and labels
171- has_first_stage = False
172- has_last_stage = False
173- for stage in stages :
174- if stage .is_first :
175- has_first_stage = True
176- if stage .is_last :
177- has_last_stage = True
178-
179- return pp_schedule , model_parts , has_first_stage , has_last_stage
188+ return num_virtual_stages , num_layers , input_weight , output_weight
180189
181190
182191def build_pipeline_schedule (
@@ -371,6 +380,106 @@ def generate_llm_fqn_per_model_part(
371380 return module_names_per_stage
372381
373382
def split_module(
    whole_model: nn.Module,
    module_names: list[str],
) -> nn.Module:
    """
    Build a model chunk containing only the requested top-level modules.

    The input model is deep-copied; direct children that are not listed in
    ``module_names`` are replaced with ``None`` (plain attributes) or pruned
    down to the requested entries (``nn.ModuleDict`` / ``nn.ModuleList``).

    Args:
        whole_model: The complete model to be split.
        module_names: Fully-qualified names of the modules to keep.

    Returns:
        A deep copy of ``whole_model`` containing only the kept modules.

    Example usage:
        module_names = ["tok_embeddings", "layers.0", "layers.1", "norm", "output"]
        split_module(whole_model, module_names)
    """
    chunk = copy.deepcopy(whole_model)
    # Set lookup is O(1) per membership test.
    keep = set(module_names)

    for child_name, child in chunk.named_children():
        if isinstance(child, (nn.ModuleDict, nn.ModuleList)):
            # Collect the sub-keys requested under this container,
            # e.g. {"0", "1"} for ["layers.0", "layers.1"].
            prefix = f"{child_name}."
            wanted = {
                name.split(".", 1)[1] for name in keep if name.startswith(prefix)
            }
            if not wanted:
                # Nothing requested from this container: leave an empty one.
                if isinstance(child, nn.ModuleDict):
                    setattr(chunk, child_name, nn.ModuleDict())
                else:
                    setattr(chunk, child_name, nn.ModuleList())
            elif isinstance(child, nn.ModuleDict):
                # Drop every entry that was not requested.
                for key in list(child.keys()):
                    if key not in wanted:
                        del child[key]
            else:  # nn.ModuleList
                wanted_indices = {int(key) for key in wanted if key.isdigit()}
                kept_layers = nn.ModuleList(
                    layer
                    for idx, layer in enumerate(child)
                    if idx in wanted_indices
                )
                setattr(chunk, child_name, kept_layers)
        elif child_name not in keep:
            # Plain child module (e.g. "norm", "output") that was not
            # requested: replace it with None.
            setattr(chunk, child_name, None)

    return chunk
441+
442+
443+ def get_pp_rank_to_stage_indices_mapping (
444+ pp_rank : int ,
445+ pp_degree ,
446+ pp_schedule : str ,
447+ num_stages : int ,
448+ ) -> tuple [int , ...]:
449+ """
450+ Returns a mapping from PP rank to stage indices for the given pipeline schedule.
451+
452+ Args:
453+ pp_rank: Pipeline parallel rank
454+ pp_degree: Number of pipeline parallel ranks
455+ pp_schedule: Name of pipeline parallelism schedule
456+ num_stages: Number of pipeline stages
457+
458+ Returns:
459+ Mapping from PP rank to stage indices
460+ """
461+ schedule_class = get_schedule_class (pp_schedule )
462+ style = (
463+ "v" if schedule_class in (ScheduleZBVZeroBubble , ScheduleDualPipeV ) else "loop"
464+ )
465+ assert (
466+ num_stages % pp_degree == 0
467+ ), f"num_stages { num_stages } must be evenly divisible by pp_degree { pp_degree } "
468+ stages_per_rank = num_stages // pp_degree
469+ if style == "loop" :
470+ return tuple (pp_rank + s * pp_degree for s in range (stages_per_rank ))
471+ elif style == "v" :
472+ assert (
473+ stages_per_rank == 2
474+ ), f"v schedules assume 2 stages per rank, got { stages_per_rank } "
475+ stage_v_pairs = list (
476+ zip (range (pp_degree ), range (num_stages - 1 , pp_degree - 1 , - 1 ))
477+ )
478+ return tuple (stage_v_pairs [pp_rank ])
479+ else :
480+ raise ValueError (f"Unknown style { style } " )
481+
482+
374483def pipeline_module_split (
375484 whole_model : nn .Module ,
376485 pp_mesh : DeviceMesh ,
@@ -412,97 +521,21 @@ def pipeline_module_split(
412521 """
413522 pp_rank = pp_mesh .get_local_rank ()
414523 pp_degree = pp_mesh .size ()
415-
416- def _build_stage_from_modules (
417- stage_idx : int , module_names : list [str ], num_stages : int
418- ) -> tuple [PipelineStage , nn .Module ]:
419- model = copy .deepcopy (whole_model )
420-
421- # Create a set of modules to keep for faster lookup
422- modules_to_keep = set (module_names )
423- for module_name , module_value in model .named_children ():
424- # Handle layer-like structures (e.g., "layers.0", "layers.1")
425- if isinstance (module_value , (nn .ModuleDict , nn .ModuleList )):
426- layers_to_keep = {
427- name .split ("." , 1 )[1 ]
428- for name in modules_to_keep
429- if name .startswith (f"{ module_name } ." )
430- }
431- if layers_to_keep :
432- # Keep only specified layers
433- if isinstance (module_value , nn .ModuleDict ):
434- for layer_name in list (module_value .keys ()):
435- if layer_name not in layers_to_keep :
436- del module_value [layer_name ]
437- elif isinstance (module_value , nn .ModuleList ):
438- indices_to_keep = {
439- int (idx ) for idx in layers_to_keep if idx .isdigit ()
440- }
441- new_layers = ModuleList (
442- [
443- layer
444- for i , layer in enumerate (module_value )
445- if i in indices_to_keep
446- ]
447- )
448- setattr (model , module_name , new_layers )
449- else :
450- # No layers from this structure needed, set to empty structure
451- if isinstance (module_value , nn .ModuleDict ):
452- setattr (model , module_name , ModuleDict ())
453- elif isinstance (module_value , nn .ModuleList ):
454- setattr (model , module_name , ModuleList ())
455- # Handle simple module attributes (e.g., "linear", "norm")
456- elif module_name not in modules_to_keep :
457- # Replace with None
458- setattr (model , module_name , None )
459-
460- stage = PipelineStage (
461- model ,
462- stage_idx ,
463- num_stages ,
464- device ,
465- group = pp_mesh .get_group ("pp" ),
466- )
467- return stage , model
468-
469524 num_stages = len (module_names_per_stage )
470525 stages = []
471526 models = []
472-
473- schedule_class = get_schedule_class (pp_schedule )
474- style = (
475- "v" if schedule_class in (ScheduleZBVZeroBubble , ScheduleDualPipeV ) else "loop"
527+ pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping (
528+ pp_rank , pp_degree , pp_schedule , num_stages
476529 )
477-
478- def _get_stage_indices () -> tuple [int , ...]:
479- """
480- Compute the stage ids for the stages that will run on this pp rank
481- for either a looped or V style schedule
482- """
483- assert (
484- num_stages % pp_degree == 0
485- ), f"num_stages { num_stages } must be evenly divisible by pp_degree { pp_degree } "
486- stages_per_rank = num_stages // pp_degree
487- if style == "loop" :
488- return tuple (pp_rank + s * pp_degree for s in range (stages_per_rank ))
489- elif style == "v" :
490- assert (
491- stages_per_rank == 2
492- ), f"v schedules assume 2 stages per rank, got { stages_per_rank } "
493- stage_v_pairs = list (
494- zip (range (pp_degree ), range (num_stages - 1 , pp_degree - 1 , - 1 ))
495- )
496- return stage_v_pairs [pp_rank ]
497- else :
498- raise ValueError (f"Unknown style { style } " )
499-
500- for stage_idx in _get_stage_indices ():
530+ for stage_idx in pp_rank_to_stage_indices :
501531 module_names = module_names_per_stage [stage_idx ]
502- stage , model_chunk = _build_stage_from_modules (
532+ model_chunk = split_module (whole_model , module_names )
533+ stage = PipelineStage (
534+ model_chunk ,
503535 stage_idx ,
504- module_names ,
505536 num_stages ,
537+ device ,
538+ group = pp_mesh .get_group ("pp" ),
506539 )
507540 logger .info (
508541 f"PP rank { pp_rank } is building stage_idx { stage_idx } "
0 commit comments