44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66import copy
7-
87import math
98import os
109from typing import Callable
1312import torch .nn as nn
1413from torch .distributed .device_mesh import DeviceMesh
1514from torch .distributed .pipelining import PipelineStage
16-
1715from torch .distributed .pipelining .schedules import (
1816 _PipelineSchedule ,
1917 _PipelineScheduleRuntime ,
2422 ScheduleDualPipeV ,
2523 ScheduleZBVZeroBubble ,
2624)
27-
2825from torchtitan .components .loss import LossFunction , rescale_accumulated_loss
2926from torchtitan .config import JobConfig
3027from torchtitan .distributed import ParallelDims
3431
3532__all__ = [
3633 "pipeline_llm" ,
34+ "get_pipeline_metadata" ,
3735 "build_pipeline_schedule" ,
3836 "generate_llm_fqn_per_model_part" ,
3937 "pipeline_module_split" ,
@@ -51,6 +49,68 @@ def pipeline_llm(
5149) -> tuple [_PipelineSchedule , list [nn .Module ], bool , bool ]:
5250 pp_mesh = parallel_dims .get_mesh ("pp" )
5351
52+ num_virtual_stages , num_layers , input_weight , output_weight = get_pipeline_metadata (
53+ parallel_dims , job_config , model_args
54+ )
55+
56+ module_names_per_stage = job_config .parallelism .module_fqns_per_model_part
57+ if module_names_per_stage is None :
58+ module_names_per_stage = generate_llm_fqn_per_model_part (
59+ num_virtual_stages , num_layers , input_weight , output_weight
60+ )
61+ for i , stage_ms in enumerate (module_names_per_stage ):
62+ logger .debug (f"Stage { i } : { stage_ms } " )
63+
64+ stages , model_parts = pipeline_module_split (
65+ model ,
66+ pp_mesh ,
67+ job_config .parallelism .pipeline_parallel_schedule ,
68+ device ,
69+ module_names_per_stage ,
70+ )
71+
72+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
73+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
74+ # optimizer, and checkpointing
75+ for i , m in enumerate (model_parts ):
76+ # apply SPMD-style PT-D techniques
77+ m = parallelize_fn (m , parallel_dims , job_config )
78+ model_parts [i ] = m
79+ # NOTE: this is to update the model in the stage
80+ # in case the model is modified e.g. by torch.compile
81+ stages [i ].submod = m
82+
83+ pp_schedule = build_pipeline_schedule (job_config , stages , loss_fn )
84+
85+ # This is used in the train loop to determine whether to pass in the input_ids and labels
86+ has_first_stage = False
87+ has_last_stage = False
88+ for stage in stages :
89+ if stage .is_first :
90+ has_first_stage = True
91+ if stage .is_last :
92+ has_last_stage = True
93+
94+ return pp_schedule , model_parts , has_first_stage , has_last_stage
95+
96+
97+ def get_pipeline_metadata (
98+ parallel_dims : ParallelDims ,
99+ job_config : JobConfig ,
100+ model_args : BaseModelArgs ,
101+ ) -> tuple [int , int , int , int ]:
102+ """
103+ Determine the number of virtual stages and the number of layers in the model.
104+
105+ Args:
106+ parallel_dims (ParallelDims): Parallel dimensions.
107+ job_config (JobConfig): Job configuration.
108+ model_args (BaseModelArgs): Model arguments.
109+
110+ Returns:
111+ tuple: A tuple containing the number of virtual stages, the number of layers in the model,
112+ the input weight, and the output weight.
113+ """
54114 # Determine the number of virtual stages based on schedule type
55115 schedule_class = get_schedule_class (
56116 job_config .parallelism .pipeline_parallel_schedule
@@ -113,46 +173,7 @@ def pipeline_llm(
113173 # For single-stage schedules, default is 1 virtual stage per rank
114174 stages_per_rank = 1 if is_single_stage_schedule else 2
115175 num_virtual_stages = parallel_dims .pp * stages_per_rank
116-
117- module_names_per_stage = job_config .parallelism .module_fqns_per_model_part
118- if module_names_per_stage is None :
119- module_names_per_stage = generate_llm_fqn_per_model_part (
120- num_virtual_stages , num_layers , input_weight , output_weight
121- )
122- for i , stage_ms in enumerate (module_names_per_stage ):
123- logger .debug (f"Stage { i } : { stage_ms } " )
124-
125- stages , model_parts = pipeline_module_split (
126- model ,
127- pp_mesh ,
128- job_config .parallelism .pipeline_parallel_schedule ,
129- device ,
130- module_names_per_stage ,
131- )
132-
133- # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
134- # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
135- # optimizer, and checkpointing
136- for i , m in enumerate (model_parts ):
137- # apply SPMD-style PT-D techniques
138- m = parallelize_fn (m , parallel_dims , job_config )
139- model_parts [i ] = m
140- # NOTE: this is to update the model in the stage
141- # in case the model is modified e.g. by torch.compile
142- stages [i ].submod = m
143-
144- pp_schedule = build_pipeline_schedule (job_config , stages , loss_fn )
145-
146- # This is used in the train loop to determine whether to pass in the input_ids and labels
147- has_first_stage = False
148- has_last_stage = False
149- for stage in stages :
150- if stage .is_first :
151- has_first_stage = True
152- if stage .is_last :
153- has_last_stage = True
154-
155- return pp_schedule , model_parts , has_first_stage , has_last_stage
176+ return num_virtual_stages , num_layers , input_weight , output_weight
156177
157178
158179def build_pipeline_schedule (
@@ -344,6 +365,106 @@ def generate_llm_fqn_per_model_part(
344365 return module_names_per_stage
345366
346367
def split_module(
    whole_model: nn.Module,
    module_names: list[str],
) -> nn.Module:
    """
    Carve a model part out of a whole model, keeping only the named submodules.

    Args:
        whole_model: The complete model to split.
        module_names: Fully-qualified names of the submodules this part keeps.

    Returns:
        A deep copy of ``whole_model`` in which container entries (ModuleList /
        ModuleDict children such as "layers.0") not listed are dropped and
        plain child attributes not listed are replaced with ``None``.

    Example usage:
        module_names = ["tok_embeddings", "layers.0", "layers.1", "norm", "output"]
        split_module(whole_model, module_names)
    """
    model = copy.deepcopy(whole_model)
    # Set lookup keeps the membership tests O(1) per child.
    requested = set(module_names)

    for child_name, child in model.named_children():
        if isinstance(child, (nn.ModuleDict, nn.ModuleList)):
            # Container child (e.g. "layers"): collect the entry names requested
            # under it, i.e. the suffix after "layers." in "layers.0".
            prefix = f"{child_name}."
            wanted = {
                fqn.split(".", 1)[1] for fqn in requested if fqn.startswith(prefix)
            }
            if not wanted:
                # Nothing from this container belongs to the part; leave an
                # empty container of the matching type in its place.
                if isinstance(child, nn.ModuleDict):
                    setattr(model, child_name, nn.ModuleDict())
                elif isinstance(child, nn.ModuleList):
                    setattr(model, child_name, nn.ModuleList())
            elif isinstance(child, nn.ModuleDict):
                # Drop unrequested keys in place.
                for key in list(child.keys()):
                    if key not in wanted:
                        del child[key]
            elif isinstance(child, nn.ModuleList):
                # Rebuild the list with only the requested (numeric) indices.
                wanted_indices = {int(entry) for entry in wanted if entry.isdigit()}
                kept = [
                    layer for idx, layer in enumerate(child) if idx in wanted_indices
                ]
                setattr(model, child_name, nn.ModuleList(kept))
        elif child_name not in requested:
            # Plain submodule this part does not own: blank it out so the
            # part's forward logic can detect and skip it.
            setattr(model, child_name, None)
    return model
426+
427+
428+ def get_pp_rank_to_stage_indices_mapping (
429+ pp_rank : int ,
430+ pp_degree ,
431+ pp_schedule : str ,
432+ num_stages : int ,
433+ ) -> tuple [int , ...]:
434+ """
435+ Returns a mapping from PP rank to stage indices for the given pipeline schedule.
436+
437+ Args:
438+ pp_rank: Pipeline parallel rank
439+ pp_degree: Number of pipeline parallel ranks
440+ pp_schedule: Name of pipeline parallelism schedule
441+ num_stages: Number of pipeline stages
442+
443+ Returns:
444+ Mapping from PP rank to stage indices
445+ """
446+ schedule_class = get_schedule_class (pp_schedule )
447+ style = (
448+ "v" if schedule_class in (ScheduleZBVZeroBubble , ScheduleDualPipeV ) else "loop"
449+ )
450+ assert (
451+ num_stages % pp_degree == 0
452+ ), f"num_stages { num_stages } must be evenly divisible by pp_degree { pp_degree } "
453+ stages_per_rank = num_stages // pp_degree
454+ if style == "loop" :
455+ return tuple (pp_rank + s * pp_degree for s in range (stages_per_rank ))
456+ elif style == "v" :
457+ assert (
458+ stages_per_rank == 2
459+ ), f"v schedules assume 2 stages per rank, got { stages_per_rank } "
460+ stage_v_pairs = list (
461+ zip (range (pp_degree ), range (num_stages - 1 , pp_degree - 1 , - 1 ))
462+ )
463+ return tuple (stage_v_pairs [pp_rank ])
464+ else :
465+ raise ValueError (f"Unknown style { style } " )
466+
467+
347468def pipeline_module_split (
348469 whole_model : nn .Module ,
349470 pp_mesh : DeviceMesh ,
@@ -385,97 +506,21 @@ def pipeline_module_split(
385506 """
386507 pp_rank = pp_mesh .get_local_rank ()
387508 pp_degree = pp_mesh .size ()
388-
389- def _build_stage_from_modules (
390- stage_idx : int , module_names : list [str ], num_stages : int
391- ) -> tuple [PipelineStage , nn .Module ]:
392- model = copy .deepcopy (whole_model )
393-
394- # Create a set of modules to keep for faster lookup
395- modules_to_keep = set (module_names )
396- for module_name , module_value in model .named_children ():
397- # Handle layer-like structures (e.g., "layers.0", "layers.1")
398- if isinstance (module_value , (nn .ModuleDict , nn .ModuleList )):
399- layers_to_keep = {
400- name .split ("." , 1 )[1 ]
401- for name in modules_to_keep
402- if name .startswith (f"{ module_name } ." )
403- }
404- if layers_to_keep :
405- # Keep only specified layers
406- if isinstance (module_value , nn .ModuleDict ):
407- for layer_name in list (module_value .keys ()):
408- if layer_name not in layers_to_keep :
409- del module_value [layer_name ]
410- elif isinstance (module_value , nn .ModuleList ):
411- indices_to_keep = {
412- int (idx ) for idx in layers_to_keep if idx .isdigit ()
413- }
414- new_layers = nn .ModuleList (
415- [
416- layer
417- for i , layer in enumerate (module_value )
418- if i in indices_to_keep
419- ]
420- )
421- setattr (model , module_name , new_layers )
422- else :
423- # No layers from this structure needed, set to empty structure
424- if isinstance (module_value , nn .ModuleDict ):
425- setattr (model , module_name , nn .ModuleDict ())
426- elif isinstance (module_value , nn .ModuleList ):
427- setattr (model , module_name , nn .ModuleList ())
428- # Handle simple module attributes (e.g., "linear", "norm")
429- elif module_name not in modules_to_keep :
430- # Replace with None
431- setattr (model , module_name , None )
432-
433- stage = PipelineStage (
434- model ,
435- stage_idx ,
436- num_stages ,
437- device ,
438- group = pp_mesh .get_group ("pp" ),
439- )
440- return stage , model
441-
442509 num_stages = len (module_names_per_stage )
443510 stages = []
444511 models = []
445-
446- schedule_class = get_schedule_class (pp_schedule )
447- style = (
448- "v" if schedule_class in (ScheduleZBVZeroBubble , ScheduleDualPipeV ) else "loop"
512+ pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping (
513+ pp_rank , pp_degree , pp_schedule , num_stages
449514 )
450-
451- def _get_stage_indices () -> tuple [int , ...]:
452- """
453- Compute the stage ids for the stages that will run on this pp rank
454- for either a looped or V style schedule
455- """
456- assert (
457- num_stages % pp_degree == 0
458- ), f"num_stages { num_stages } must be evenly divisible by pp_degree { pp_degree } "
459- stages_per_rank = num_stages // pp_degree
460- if style == "loop" :
461- return tuple (pp_rank + s * pp_degree for s in range (stages_per_rank ))
462- elif style == "v" :
463- assert (
464- stages_per_rank == 2
465- ), f"v schedules assume 2 stages per rank, got { stages_per_rank } "
466- stage_v_pairs = list (
467- zip (range (pp_degree ), range (num_stages - 1 , pp_degree - 1 , - 1 ))
468- )
469- return stage_v_pairs [pp_rank ]
470- else :
471- raise ValueError (f"Unknown style { style } " )
472-
473- for stage_idx in _get_stage_indices ():
515+ for stage_idx in pp_rank_to_stage_indices :
474516 module_names = module_names_per_stage [stage_idx ]
475- stage , model_chunk = _build_stage_from_modules (
517+ model_chunk = split_module (whole_model , module_names )
518+ stage = PipelineStage (
519+ model_chunk ,
476520 stage_idx ,
477- module_names ,
478521 num_stages ,
522+ device ,
523+ group = pp_mesh .get_group ("pp" ),
479524 )
480525 logger .info (
481526 f"PP rank { pp_rank } is building stage_idx { stage_idx } "
0 commit comments