Skip to content

Commit a7a3260

Browse files
Enable graph_pp for autoparallel in torchtitan
stack-info: PR: #2271, branch: sanketpurandare/stack/2
1 parent 6d42f9a commit a7a3260

File tree

2 files changed

+172
-127
lines changed

2 files changed

+172
-127
lines changed

torchtitan/distributed/pipeline_parallel.py

Lines changed: 172 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import copy
7-
87
import math
98
import os
109
from typing import Callable
@@ -13,7 +12,6 @@
1312
import torch.nn as nn
1413
from torch.distributed.device_mesh import DeviceMesh
1514
from torch.distributed.pipelining import PipelineStage
16-
1715
from torch.distributed.pipelining.schedules import (
1816
_PipelineSchedule,
1917
_PipelineScheduleRuntime,
@@ -24,7 +22,6 @@
2422
ScheduleDualPipeV,
2523
ScheduleZBVZeroBubble,
2624
)
27-
2825
from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
2926
from torchtitan.config import JobConfig
3027
from torchtitan.distributed import ParallelDims
@@ -34,6 +31,7 @@
3431

3532
__all__ = [
3633
"pipeline_llm",
34+
"get_pipeline_metadata",
3735
"build_pipeline_schedule",
3836
"generate_llm_fqn_per_model_part",
3937
"pipeline_module_split",
@@ -51,6 +49,68 @@ def pipeline_llm(
5149
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
5250
pp_mesh = parallel_dims.get_mesh("pp")
5351

52+
num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
53+
parallel_dims, job_config, model_args
54+
)
55+
56+
module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
57+
if module_names_per_stage is None:
58+
module_names_per_stage = generate_llm_fqn_per_model_part(
59+
num_virtual_stages, num_layers, input_weight, output_weight
60+
)
61+
for i, stage_ms in enumerate(module_names_per_stage):
62+
logger.debug(f"Stage {i}: {stage_ms}")
63+
64+
stages, model_parts = pipeline_module_split(
65+
model,
66+
pp_mesh,
67+
job_config.parallelism.pipeline_parallel_schedule,
68+
device,
69+
module_names_per_stage,
70+
)
71+
72+
# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
73+
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
74+
# optimizer, and checkpointing
75+
for i, m in enumerate(model_parts):
76+
# apply SPMD-style PT-D techniques
77+
m = parallelize_fn(m, parallel_dims, job_config)
78+
model_parts[i] = m
79+
# NOTE: this is to update the model in the stage
80+
# in case the model is modified e.g. by torch.compile
81+
stages[i].submod = m
82+
83+
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
84+
85+
# This is used in the train loop to determine whether to pass in the input_ids and labels
86+
has_first_stage = False
87+
has_last_stage = False
88+
for stage in stages:
89+
if stage.is_first:
90+
has_first_stage = True
91+
if stage.is_last:
92+
has_last_stage = True
93+
94+
return pp_schedule, model_parts, has_first_stage, has_last_stage
95+
96+
97+
def get_pipeline_metadata(
98+
parallel_dims: ParallelDims,
99+
job_config: JobConfig,
100+
model_args: BaseModelArgs,
101+
) -> tuple[int, int, int, int]:
102+
"""
103+
Determine the number of virtual stages and the number of layers in the model.
104+
105+
Args:
106+
parallel_dims (ParallelDims): Parallel dimensions.
107+
job_config (JobConfig): Job configuration.
108+
model_args (BaseModelArgs): Model arguments.
109+
110+
Returns:
111+
tuple: A tuple containing the number of virtual stages, the number of layers in the model,
112+
the input weight, and the output weight.
113+
"""
54114
# Determine the number of virtual stages based on schedule type
55115
schedule_class = get_schedule_class(
56116
job_config.parallelism.pipeline_parallel_schedule
@@ -113,46 +173,7 @@ def pipeline_llm(
113173
# For single-stage schedules, default is 1 virtual stage per rank
114174
stages_per_rank = 1 if is_single_stage_schedule else 2
115175
num_virtual_stages = parallel_dims.pp * stages_per_rank
116-
117-
module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
118-
if module_names_per_stage is None:
119-
module_names_per_stage = generate_llm_fqn_per_model_part(
120-
num_virtual_stages, num_layers, input_weight, output_weight
121-
)
122-
for i, stage_ms in enumerate(module_names_per_stage):
123-
logger.debug(f"Stage {i}: {stage_ms}")
124-
125-
stages, model_parts = pipeline_module_split(
126-
model,
127-
pp_mesh,
128-
job_config.parallelism.pipeline_parallel_schedule,
129-
device,
130-
module_names_per_stage,
131-
)
132-
133-
# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
134-
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
135-
# optimizer, and checkpointing
136-
for i, m in enumerate(model_parts):
137-
# apply SPMD-style PT-D techniques
138-
m = parallelize_fn(m, parallel_dims, job_config)
139-
model_parts[i] = m
140-
# NOTE: this is to update the model in the stage
141-
# in case the model is modified e.g. by torch.compile
142-
stages[i].submod = m
143-
144-
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
145-
146-
# This is used in the train loop to determine whether to pass in the input_ids and labels
147-
has_first_stage = False
148-
has_last_stage = False
149-
for stage in stages:
150-
if stage.is_first:
151-
has_first_stage = True
152-
if stage.is_last:
153-
has_last_stage = True
154-
155-
return pp_schedule, model_parts, has_first_stage, has_last_stage
176+
return num_virtual_stages, num_layers, input_weight, output_weight
156177

157178

158179
def build_pipeline_schedule(
@@ -344,6 +365,106 @@ def generate_llm_fqn_per_model_part(
344365
return module_names_per_stage
345366

346367

368+
def split_module(
    whole_model: nn.Module,
    module_names: list[str],
) -> nn.Module:
    """
    Build one model chunk by pruning a deep copy of ``whole_model``.

    Only the direct children of the model are inspected:

    * A child that is an ``nn.ModuleDict`` or ``nn.ModuleList`` keeps just the
      entries named ``"<child>.<key>"`` in ``module_names``. If no such name
      is listed, the child is replaced by an empty container of the same kind.
    * Any other child that is not listed in ``module_names`` is set to ``None``.

    Args:
        whole_model: The complete model. It is deep-copied and never mutated.
        module_names: Names of the modules that must survive in the chunk,
            e.g. ``["tok_embeddings", "layers.0", "layers.1", "norm"]``.

    Returns:
        The pruned copy of ``whole_model``.

    Note:
        For an ``nn.ModuleList`` the surviving layers are packed into a new
        ``nn.ModuleList``, so their positions restart from 0. Suffixes that
        are not plain digits are ignored for ``nn.ModuleList`` children.

    Example usage:
        module_names = ["tok_embeddings", "layers.0", "layers.1", "norm", "output"]
        split_module(whole_model, module_names)
    """
    chunk = copy.deepcopy(whole_model)
    wanted = set(module_names)

    for child_name, child in chunk.named_children():
        is_dict = isinstance(child, nn.ModuleDict)
        is_list = isinstance(child, nn.ModuleList)

        # Plain attribute (e.g. "norm", "output"): keep it or blank it out.
        if not (is_dict or is_list):
            if child_name not in wanted:
                setattr(chunk, child_name, None)
            continue

        # Container (e.g. "layers"): collect the requested entry suffixes.
        prefix = f"{child_name}."
        kept = {fqn.split(".", 1)[1] for fqn in wanted if fqn.startswith(prefix)}

        # Nothing requested from this container: swap in an empty one.
        if not kept:
            setattr(chunk, child_name, nn.ModuleDict() if is_dict else nn.ModuleList())
            continue

        if is_dict:
            # Prune in place; snapshot the doomed keys before deleting.
            doomed = [key for key in child.keys() if key not in kept]
            for key in doomed:
                del child[key]
        else:
            kept_positions = {int(suffix) for suffix in kept if suffix.isdigit()}
            survivors = [
                layer for pos, layer in enumerate(child) if pos in kept_positions
            ]
            setattr(chunk, child_name, nn.ModuleList(survivors))

    return chunk
426+
427+
428+
def get_pp_rank_to_stage_indices_mapping(
429+
pp_rank: int,
430+
pp_degree,
431+
pp_schedule: str,
432+
num_stages: int,
433+
) -> tuple[int, ...]:
434+
"""
435+
Returns a mapping from PP rank to stage indices for the given pipeline schedule.
436+
437+
Args:
438+
pp_rank: Pipeline parallel rank
439+
pp_degree: Number of pipeline parallel ranks
440+
pp_schedule: Name of pipeline parallelism schedule
441+
num_stages: Number of pipeline stages
442+
443+
Returns:
444+
Mapping from PP rank to stage indices
445+
"""
446+
schedule_class = get_schedule_class(pp_schedule)
447+
style = (
448+
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
449+
)
450+
assert (
451+
num_stages % pp_degree == 0
452+
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
453+
stages_per_rank = num_stages // pp_degree
454+
if style == "loop":
455+
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
456+
elif style == "v":
457+
assert (
458+
stages_per_rank == 2
459+
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
460+
stage_v_pairs = list(
461+
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
462+
)
463+
return tuple(stage_v_pairs[pp_rank])
464+
else:
465+
raise ValueError(f"Unknown style {style}")
466+
467+
347468
def pipeline_module_split(
348469
whole_model: nn.Module,
349470
pp_mesh: DeviceMesh,
@@ -385,97 +506,21 @@ def pipeline_module_split(
385506
"""
386507
pp_rank = pp_mesh.get_local_rank()
387508
pp_degree = pp_mesh.size()
388-
389-
def _build_stage_from_modules(
390-
stage_idx: int, module_names: list[str], num_stages: int
391-
) -> tuple[PipelineStage, nn.Module]:
392-
model = copy.deepcopy(whole_model)
393-
394-
# Create a set of modules to keep for faster lookup
395-
modules_to_keep = set(module_names)
396-
for module_name, module_value in model.named_children():
397-
# Handle layer-like structures (e.g., "layers.0", "layers.1")
398-
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
399-
layers_to_keep = {
400-
name.split(".", 1)[1]
401-
for name in modules_to_keep
402-
if name.startswith(f"{module_name}.")
403-
}
404-
if layers_to_keep:
405-
# Keep only specified layers
406-
if isinstance(module_value, nn.ModuleDict):
407-
for layer_name in list(module_value.keys()):
408-
if layer_name not in layers_to_keep:
409-
del module_value[layer_name]
410-
elif isinstance(module_value, nn.ModuleList):
411-
indices_to_keep = {
412-
int(idx) for idx in layers_to_keep if idx.isdigit()
413-
}
414-
new_layers = nn.ModuleList(
415-
[
416-
layer
417-
for i, layer in enumerate(module_value)
418-
if i in indices_to_keep
419-
]
420-
)
421-
setattr(model, module_name, new_layers)
422-
else:
423-
# No layers from this structure needed, set to empty structure
424-
if isinstance(module_value, nn.ModuleDict):
425-
setattr(model, module_name, nn.ModuleDict())
426-
elif isinstance(module_value, nn.ModuleList):
427-
setattr(model, module_name, nn.ModuleList())
428-
# Handle simple module attributes (e.g., "linear", "norm")
429-
elif module_name not in modules_to_keep:
430-
# Replace with None
431-
setattr(model, module_name, None)
432-
433-
stage = PipelineStage(
434-
model,
435-
stage_idx,
436-
num_stages,
437-
device,
438-
group=pp_mesh.get_group("pp"),
439-
)
440-
return stage, model
441-
442509
num_stages = len(module_names_per_stage)
443510
stages = []
444511
models = []
445-
446-
schedule_class = get_schedule_class(pp_schedule)
447-
style = (
448-
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
512+
pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping(
513+
pp_rank, pp_degree, pp_schedule, num_stages
449514
)
450-
451-
def _get_stage_indices() -> tuple[int, ...]:
452-
"""
453-
Compute the stage ids for the stages that will run on this pp rank
454-
for either a looped or V style schedule
455-
"""
456-
assert (
457-
num_stages % pp_degree == 0
458-
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
459-
stages_per_rank = num_stages // pp_degree
460-
if style == "loop":
461-
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
462-
elif style == "v":
463-
assert (
464-
stages_per_rank == 2
465-
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
466-
stage_v_pairs = list(
467-
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
468-
)
469-
return stage_v_pairs[pp_rank]
470-
else:
471-
raise ValueError(f"Unknown style {style}")
472-
473-
for stage_idx in _get_stage_indices():
515+
for stage_idx in pp_rank_to_stage_indices:
474516
module_names = module_names_per_stage[stage_idx]
475-
stage, model_chunk = _build_stage_from_modules(
517+
model_chunk = split_module(whole_model, module_names)
518+
stage = PipelineStage(
519+
model_chunk,
476520
stage_idx,
477-
module_names,
478521
num_stages,
522+
device,
523+
group=pp_mesh.get_group("pp"),
479524
)
480525
logger.info(
481526
f"PP rank {pp_rank} is building stage_idx {stage_idx} "

torchtitan/experiments/autoparallel/graph_pp_builder.py

Whitespace-only changes.

0 commit comments

Comments
 (0)