Skip to content

Commit ae1ab48

Browse files
Extract get_pipeline_metadata and remove stale deepseek_v3 experiment
stack-info: PR: #2271, branch: sanketpurandare/stack/2
1 parent d84e83d commit ae1ab48

File tree

4 files changed

+172
-679
lines changed

4 files changed

+172
-679
lines changed

torchtitan/distributed/pipeline_parallel.py

Lines changed: 172 additions & 139 deletions
Original file line number · Diff line number · Diff line change
@@ -22,7 +22,6 @@
2222
ScheduleDualPipeV,
2323
ScheduleZBVZeroBubble,
2424
)
25-
2625
from torchtitan.components.loss import LossFunction
2726
from torchtitan.config import (
2827
ActivationCheckpointConfig,
@@ -40,6 +39,7 @@
4039

4140
__all__ = [
4241
"pipeline_llm",
42+
"get_pipeline_metadata",
4343
"build_pipeline_schedule",
4444
"generate_llm_fqn_per_model_part",
4545
"pipeline_module_split",
@@ -63,6 +63,68 @@ def pipeline_llm(
6363
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
6464
pp_mesh = parallel_dims.get_mesh("pp")
6565

66+
num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
67+
parallel_dims, job_config, model_args
68+
)
69+
70+
module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
71+
if module_names_per_stage is None:
72+
module_names_per_stage = generate_llm_fqn_per_model_part(
73+
num_virtual_stages, num_layers, input_weight, output_weight
74+
)
75+
for i, stage_ms in enumerate(module_names_per_stage):
76+
logger.debug(f"Stage {i}: {stage_ms}")
77+
78+
stages, model_parts = pipeline_module_split(
79+
model,
80+
pp_mesh,
81+
job_config.parallelism.pipeline_parallel_schedule,
82+
device,
83+
module_names_per_stage,
84+
)
85+
86+
# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
87+
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
88+
# optimizer, and checkpointing
89+
for i, m in enumerate(model_parts):
90+
# apply SPMD-style PT-D techniques
91+
m = parallelize_fn(m, parallel_dims, job_config)
92+
model_parts[i] = m
93+
# NOTE: this is to update the model in the stage
94+
# in case the model is modified e.g. by torch.compile
95+
stages[i].submod = m
96+
97+
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
98+
99+
# This is used in the train loop to determine whether to pass in the input_ids and labels
100+
has_first_stage = False
101+
has_last_stage = False
102+
for stage in stages:
103+
if stage.is_first:
104+
has_first_stage = True
105+
if stage.is_last:
106+
has_last_stage = True
107+
108+
return pp_schedule, model_parts, has_first_stage, has_last_stage
109+
110+
111+
def get_pipeline_metadata(
112+
parallel_dims: ParallelDims,
113+
job_config: JobConfig,
114+
model_args: BaseModelArgs,
115+
) -> tuple[int, int, int, int]:
116+
"""
117+
Determine the number of virtual stages and the number of layers in the model.
118+
119+
Args:
120+
parallel_dims (ParallelDims): Parallel dimensions.
121+
job_config (JobConfig): Job configuration.
122+
model_args (BaseModelArgs): Model arguments.
123+
124+
Returns:
125+
tuple: A tuple containing the number of virtual stages, the number of layers in the model,
126+
the input weight, and the output weight.
127+
"""
66128
# Determine the number of virtual stages based on schedule type
67129
schedule_class = get_schedule_class(parallelism.pipeline_parallel_schedule)
68130
is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
@@ -123,60 +185,7 @@ def pipeline_llm(
123185
# For single-stage schedules, default is 1 virtual stage per rank
124186
stages_per_rank = 1 if is_single_stage_schedule else 2
125187
num_virtual_stages = parallel_dims.pp * stages_per_rank
126-
127-
module_names_per_stage = parallelism.module_fqns_per_model_part
128-
if module_names_per_stage is None:
129-
module_names_per_stage = generate_llm_fqn_per_model_part(
130-
num_virtual_stages, num_layers, input_weight, output_weight
131-
)
132-
for i, stage_ms in enumerate(module_names_per_stage):
133-
logger.debug(f"Stage {i}: {stage_ms}")
134-
135-
stages, model_parts = pipeline_module_split(
136-
model,
137-
pp_mesh,
138-
parallelism.pipeline_parallel_schedule,
139-
device,
140-
module_names_per_stage,
141-
)
142-
143-
# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
144-
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
145-
# optimizer, and checkpointing
146-
for i, m in enumerate(model_parts):
147-
# apply SPMD-style PT-D techniques
148-
m = parallelize_fn(
149-
m,
150-
parallel_dims=parallel_dims,
151-
training=training,
152-
model_converters=model_converters,
153-
parallelism=parallelism,
154-
compile_config=compile_config,
155-
ac_config=ac_config,
156-
dump_folder=dump_folder,
157-
)
158-
model_parts[i] = m
159-
# NOTE: this is to update the model in the stage
160-
# in case the model is modified e.g. by torch.compile
161-
stages[i].submod = m
162-
163-
pp_schedule = build_pipeline_schedule(
164-
parallelism=parallelism,
165-
local_batch_size=training.local_batch_size,
166-
stages=stages,
167-
loss_fn=loss_fn,
168-
)
169-
170-
# This is used in the train loop to determine whether to pass in the input_ids and labels
171-
has_first_stage = False
172-
has_last_stage = False
173-
for stage in stages:
174-
if stage.is_first:
175-
has_first_stage = True
176-
if stage.is_last:
177-
has_last_stage = True
178-
179-
return pp_schedule, model_parts, has_first_stage, has_last_stage
188+
return num_virtual_stages, num_layers, input_weight, output_weight
180189

181190

182191
def build_pipeline_schedule(
@@ -371,6 +380,106 @@ def generate_llm_fqn_per_model_part(
371380
return module_names_per_stage
372381

373382

383+
def split_module(
    whole_model: nn.Module,
    module_names: list[str],
) -> nn.Module:
    """
    Create a deep copy of ``whole_model`` that retains only the requested modules.

    Args:
        whole_model: The complete model to be split.
        module_names: Fully-qualified names of the modules to retain, e.g.
            ``["tok_embeddings", "layers.0", "layers.1", "norm", "output"]``.

    Returns:
        A deep-copied model in which every top-level child not listed in
        ``module_names`` is replaced by ``None`` (plain attributes) or pruned
        (entries inside ``nn.ModuleDict`` / ``nn.ModuleList`` containers).
    """
    keep = set(module_names)
    stage_model = copy.deepcopy(whole_model)

    for child_name, child in stage_model.named_children():
        if isinstance(child, (nn.ModuleDict, nn.ModuleList)):
            # Collect the sub-entry names requested under this container,
            # e.g. "layers.0" contributes "0" when the container is "layers".
            prefix = f"{child_name}."
            wanted = {name[len(prefix):] for name in keep if name.startswith(prefix)}
            if not wanted:
                # Nothing requested from this container: keep an empty one.
                empty = (
                    nn.ModuleDict()
                    if isinstance(child, nn.ModuleDict)
                    else nn.ModuleList()
                )
                setattr(stage_model, child_name, empty)
            elif isinstance(child, nn.ModuleDict):
                # Drop unrequested keys in place on the copy.
                for key in list(child.keys()):
                    if key not in wanted:
                        del child[key]
            else:
                # nn.ModuleList: rebuild with only the requested integer indices.
                numeric = {int(s) for s in wanted if s.isdigit()}
                kept_layers = nn.ModuleList(
                    layer for idx, layer in enumerate(child) if idx in numeric
                )
                setattr(stage_model, child_name, kept_layers)
        elif child_name not in keep:
            # Plain child module that was not requested: null it out so the
            # stage model still exposes the attribute but carries no weights.
            setattr(stage_model, child_name, None)

    return stage_model
441+
442+
443+
def get_pp_rank_to_stage_indices_mapping(
444+
pp_rank: int,
445+
pp_degree,
446+
pp_schedule: str,
447+
num_stages: int,
448+
) -> tuple[int, ...]:
449+
"""
450+
Returns a mapping from PP rank to stage indices for the given pipeline schedule.
451+
452+
Args:
453+
pp_rank: Pipeline parallel rank
454+
pp_degree: Number of pipeline parallel ranks
455+
pp_schedule: Name of pipeline parallelism schedule
456+
num_stages: Number of pipeline stages
457+
458+
Returns:
459+
Mapping from PP rank to stage indices
460+
"""
461+
schedule_class = get_schedule_class(pp_schedule)
462+
style = (
463+
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
464+
)
465+
assert (
466+
num_stages % pp_degree == 0
467+
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
468+
stages_per_rank = num_stages // pp_degree
469+
if style == "loop":
470+
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
471+
elif style == "v":
472+
assert (
473+
stages_per_rank == 2
474+
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
475+
stage_v_pairs = list(
476+
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
477+
)
478+
return tuple(stage_v_pairs[pp_rank])
479+
else:
480+
raise ValueError(f"Unknown style {style}")
481+
482+
374483
def pipeline_module_split(
375484
whole_model: nn.Module,
376485
pp_mesh: DeviceMesh,
@@ -412,97 +521,21 @@ def pipeline_module_split(
412521
"""
413522
pp_rank = pp_mesh.get_local_rank()
414523
pp_degree = pp_mesh.size()
415-
416-
def _build_stage_from_modules(
417-
stage_idx: int, module_names: list[str], num_stages: int
418-
) -> tuple[PipelineStage, nn.Module]:
419-
model = copy.deepcopy(whole_model)
420-
421-
# Create a set of modules to keep for faster lookup
422-
modules_to_keep = set(module_names)
423-
for module_name, module_value in model.named_children():
424-
# Handle layer-like structures (e.g., "layers.0", "layers.1")
425-
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
426-
layers_to_keep = {
427-
name.split(".", 1)[1]
428-
for name in modules_to_keep
429-
if name.startswith(f"{module_name}.")
430-
}
431-
if layers_to_keep:
432-
# Keep only specified layers
433-
if isinstance(module_value, nn.ModuleDict):
434-
for layer_name in list(module_value.keys()):
435-
if layer_name not in layers_to_keep:
436-
del module_value[layer_name]
437-
elif isinstance(module_value, nn.ModuleList):
438-
indices_to_keep = {
439-
int(idx) for idx in layers_to_keep if idx.isdigit()
440-
}
441-
new_layers = ModuleList(
442-
[
443-
layer
444-
for i, layer in enumerate(module_value)
445-
if i in indices_to_keep
446-
]
447-
)
448-
setattr(model, module_name, new_layers)
449-
else:
450-
# No layers from this structure needed, set to empty structure
451-
if isinstance(module_value, nn.ModuleDict):
452-
setattr(model, module_name, ModuleDict())
453-
elif isinstance(module_value, nn.ModuleList):
454-
setattr(model, module_name, ModuleList())
455-
# Handle simple module attributes (e.g., "linear", "norm")
456-
elif module_name not in modules_to_keep:
457-
# Replace with None
458-
setattr(model, module_name, None)
459-
460-
stage = PipelineStage(
461-
model,
462-
stage_idx,
463-
num_stages,
464-
device,
465-
group=pp_mesh.get_group("pp"),
466-
)
467-
return stage, model
468-
469524
num_stages = len(module_names_per_stage)
470525
stages = []
471526
models = []
472-
473-
schedule_class = get_schedule_class(pp_schedule)
474-
style = (
475-
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
527+
pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping(
528+
pp_rank, pp_degree, pp_schedule, num_stages
476529
)
477-
478-
def _get_stage_indices() -> tuple[int, ...]:
479-
"""
480-
Compute the stage ids for the stages that will run on this pp rank
481-
for either a looped or V style schedule
482-
"""
483-
assert (
484-
num_stages % pp_degree == 0
485-
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
486-
stages_per_rank = num_stages // pp_degree
487-
if style == "loop":
488-
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
489-
elif style == "v":
490-
assert (
491-
stages_per_rank == 2
492-
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
493-
stage_v_pairs = list(
494-
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
495-
)
496-
return stage_v_pairs[pp_rank]
497-
else:
498-
raise ValueError(f"Unknown style {style}")
499-
500-
for stage_idx in _get_stage_indices():
530+
for stage_idx in pp_rank_to_stage_indices:
501531
module_names = module_names_per_stage[stage_idx]
502-
stage, model_chunk = _build_stage_from_modules(
532+
model_chunk = split_module(whole_model, module_names)
533+
stage = PipelineStage(
534+
model_chunk,
503535
stage_idx,
504-
module_names,
505536
num_stages,
537+
device,
538+
group=pp_mesh.get_group("pp"),
506539
)
507540
logger.info(
508541
f"PP rank {pp_rank} is building stage_idx {stage_idx} "

0 commit comments

Comments
 (0)