Skip to content

Commit 2d2fb54

Browse files
Sanket Jayant Purandare (sanketpurandare)
authored and committed
Refactor pipeline_parallel.py for graph PP reuse
1 parent ae1ab48 commit 2d2fb54

File tree

1 file changed

+37
-16
lines changed

1 file changed

+37
-16
lines changed

torchtitan/distributed/pipeline_parallel.py

Lines changed: 37 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,8 @@
4343
"build_pipeline_schedule",
4444
"generate_llm_fqn_per_model_part",
4545
"pipeline_module_split",
46+
"split_module",
47+
"get_pp_rank_to_stage_indices_mapping",
4648
]
4749

4850

@@ -64,10 +66,10 @@ def pipeline_llm(
6466
pp_mesh = parallel_dims.get_mesh("pp")
6567

6668
num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
67-
parallel_dims, job_config, model_args
69+
parallel_dims, parallelism, model_config
6870
)
6971

70-
module_names_per_stage = job_config.parallelism.module_fqns_per_model_part
72+
module_names_per_stage = parallelism.module_fqns_per_model_part
7173
if module_names_per_stage is None:
7274
module_names_per_stage = generate_llm_fqn_per_model_part(
7375
num_virtual_stages, num_layers, input_weight, output_weight
@@ -78,7 +80,7 @@ def pipeline_llm(
7880
stages, model_parts = pipeline_module_split(
7981
model,
8082
pp_mesh,
81-
job_config.parallelism.pipeline_parallel_schedule,
83+
parallelism.pipeline_parallel_schedule,
8284
device,
8385
module_names_per_stage,
8486
)
@@ -88,13 +90,27 @@ def pipeline_llm(
8890
# optimizer, and checkpointing
8991
for i, m in enumerate(model_parts):
9092
# apply SPMD-style PT-D techniques
91-
m = parallelize_fn(m, parallel_dims, job_config)
93+
m = parallelize_fn(
94+
m,
95+
parallel_dims=parallel_dims,
96+
training=training,
97+
model_converters=model_converters,
98+
parallelism=parallelism,
99+
compile_config=compile_config,
100+
ac_config=ac_config,
101+
dump_folder=dump_folder,
102+
)
92103
model_parts[i] = m
93104
# NOTE: this is to update the model in the stage
94105
# in case the model is modified e.g. by torch.compile
95106
stages[i].submod = m
96107

97-
pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
108+
pp_schedule = build_pipeline_schedule(
109+
parallelism=parallelism,
110+
local_batch_size=training.local_batch_size,
111+
stages=stages,
112+
loss_fn=loss_fn,
113+
)
98114

99115
# This is used in the train loop to determine whether to pass in the input_ids and labels
100116
has_first_stage = False
@@ -110,16 +126,16 @@ def pipeline_llm(
110126

111127
def get_pipeline_metadata(
112128
parallel_dims: ParallelDims,
113-
job_config: JobConfig,
114-
model_args: BaseModelArgs,
129+
parallelism: ParallelismConfig,
130+
model_config: BaseModel.Config,
115131
) -> tuple[int, int, int, int]:
116132
"""
117133
Determine the number of virtual stages and the number of layers in the model.
118134
119135
Args:
120136
parallel_dims (ParallelDims): Parallel dimensions.
121-
job_config (JobConfig): Job configuration.
122-
model_args (BaseModelArgs): Model arguments.
137+
parallelism (ParallelismConfig): Parallelism configuration.
138+
model_config (BaseModel.Config): Model configuration.
123139
124140
Returns:
125141
tuple: A tuple containing the number of virtual stages, the number of layers in the model,
@@ -194,6 +210,8 @@ def build_pipeline_schedule(
194210
local_batch_size: int,
195211
stages: list[PipelineStage],
196212
loss_fn: Callable,
213+
backward_requires_autograd: bool = True,
214+
scale_grads: bool = False,
197215
) -> _PipelineSchedule:
198216
"""Builds a pipeline schedule for the given job configuration and stages.
199217
@@ -242,7 +260,8 @@ def build_pipeline_schedule(
242260
stages if looped_schedule else stages[0],
243261
n_microbatches=n_microbatches,
244262
loss_fn=loss_fn,
245-
scale_grads=False,
263+
backward_requires_autograd=backward_requires_autograd,
264+
scale_grads=scale_grads,
246265
)
247266
logger.info(
248267
f"Using pipeline schedule {parallelism.pipeline_parallel_schedule} "
@@ -403,7 +422,9 @@ def split_module(
403422
modules_to_keep = set(module_names)
404423
for module_name, module_value in model.named_children():
405424
# Handle layer-like structures (e.g., "layers.0", "layers.1")
406-
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
425+
if isinstance(
426+
module_value, (nn.ModuleDict, nn.ModuleList, ModuleDict, ModuleList)
427+
):
407428
layers_to_keep = {
408429
name.split(".", 1)[1]
409430
for name in modules_to_keep
@@ -419,7 +440,7 @@ def split_module(
419440
indices_to_keep = {
420441
int(idx) for idx in layers_to_keep if idx.isdigit()
421442
}
422-
new_layers = nn.ModuleList(
443+
new_layers = ModuleList(
423444
[
424445
layer
425446
for i, layer in enumerate(module_value)
@@ -429,10 +450,10 @@ def split_module(
429450
setattr(model, module_name, new_layers)
430451
else:
431452
# No layers from this structure needed, set to empty structure
432-
if isinstance(module_value, nn.ModuleDict):
433-
setattr(model, module_name, nn.ModuleDict())
434-
elif isinstance(module_value, nn.ModuleList):
435-
setattr(model, module_name, nn.ModuleList())
453+
if isinstance(module_value, (nn.ModuleDict, ModuleDict)):
454+
setattr(model, module_name, ModuleDict())
455+
elif isinstance(module_value, (nn.ModuleList, ModuleList)):
456+
setattr(model, module_name, ModuleList())
436457
# Handle simple module attributes (e.g., "linear", "norm")
437458
elif module_name not in modules_to_keep:
438459
# Replace with None

0 commit comments

Comments (0)