Commit 8b737fe

[PP] Add pipeline_parallel_layers_per_stage config (#1072)
landing #1041
1 parent 5568431 commit 8b737fe

4 files changed (+89, -33 lines)

tests/integration_tests.py

Lines changed: 11 additions & 0 deletions
@@ -194,6 +194,12 @@ def build_test_list():
                     "--parallelism.pipeline_parallel_schedule 1F1B",
                     "--parallelism.data_parallel_shard_degree 2",
                 ],
+                [
+                    "--parallelism.pipeline_parallel_degree 2",
+                    "--parallelism.pipeline_parallel_schedule 1F1B",
+                    "--parallelism.pipeline_parallel_layers_per_stage 4",
+                    "--parallelism.data_parallel_shard_degree 2",
+                ],
             ],
             "PP+DP 1F1B 2D test",
             "pp_dp_1f1b",
@@ -258,6 +264,11 @@ def build_test_list():
                     "--parallelism.pipeline_parallel_degree 4",
                     "--parallelism.pipeline_parallel_schedule Interleaved1F1B",
                 ],
+                [
+                    "--parallelism.pipeline_parallel_degree 4",
+                    "--parallelism.pipeline_parallel_schedule Interleaved1F1B",
+                    "--parallelism.pipeline_parallel_layers_per_stage 1",
+                ],
             ],
             "PP looped 1F1B test",
             "pp_looped_1f1b",
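As a quick sanity check on these overrides, the sketch below (illustrative only, not part of the test suite, and assuming the debug model used by these tests has 8 transformer layers) mirrors the divisibility rule the new flag enforces:

import math

def stages_per_rank(num_layers: int, layers_per_stage: int, pp_degree: int) -> int:
    # Mirror of the check applied when pipeline_parallel_layers_per_stage is set.
    total_stages = math.ceil(num_layers / layers_per_stage)
    assert total_stages % pp_degree == 0, "total stages must divide evenly across pp ranks"
    return total_stages // pp_degree

print(stages_per_rank(8, layers_per_stage=4, pp_degree=2))  # 1 stage/rank -> valid for 1F1B (single-stage)
print(stages_per_rank(8, layers_per_stage=1, pp_degree=4))  # 2 stages/rank -> valid for Interleaved1F1B (multi-stage)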

torchtitan/config_manager.py

Lines changed: 9 additions & 0 deletions
@@ -487,6 +487,15 @@ def __init__(self):
                 Note: fully-automated splitting may be enabled in the future,
                 but currently the split points must be specified manually.""",
         )
+        self.parser.add_argument(
+            "--parallelism.pipeline_parallel_layers_per_stage",
+            type=int,
+            default=None,
+            help="""
+                The number of layers per stage. If specified, the split points will be calculated from
+                the number of layers and pipeline_parallel_degree. If not specified, the layers per stage will
+                be inferred from the model, schedule, and pipeline_parallel_degree.""",
+        )
         self.parser.add_argument(
            "--parallelism.pipeline_parallel_schedule",
            type=str,
torchtitan/distributed/pipeline.py

Lines changed: 68 additions & 33 deletions
@@ -3,8 +3,9 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 import os
-from typing import Callable
+from typing import Callable, Optional
 
 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
@@ -25,50 +26,84 @@
 # TODO: It's unclear if this API is general enough to be used by other models.
 # If not, we should move it to a Transformer-specific directory.
 def generate_split_points(
-    pipeline_parallel_schedule: str, pp_dim: int, num_layers: int
+    schedule_str: str,
+    layers_per_stage: Optional[int],
+    pp_dim: int,
+    num_layers: int,
+    input_weight: int = 1,
+    output_weight: int = 1,
 ) -> list[str]:
     """
-    Generate a default split point based on the number of layers and
-    pipeline parallel dimension.
+    Generate a list of split points based on the number of layers and
+    pipeline parallel dimension, ensuring the first and last stages have the fewest layers.
 
     Args:
-        job_config (JobConfig): The job configuration.
+        schedule_str (str): The string of the schedule name.
+        layers_per_stage (int): The number of layers per stage.
         pp_dim (int): The pipeline parallel dimension.
         num_layers (int): The number of layers in the model.
+        input_weight / output_weight (int): The number of layers to attribute to the input/output modules in the layer calculation.
 
     Returns:
         list[str]: A list of split point FQNs.
     """
 
-    schedule_class = get_schedule_class(pipeline_parallel_schedule)
-    if issubclass(schedule_class, PipelineScheduleSingle):
-        num_stages_per_rank = 1
-    elif issubclass(schedule_class, PipelineScheduleMulti):
-        # Multi-stage schedules support more than 2 stages per rank, but this is the default if
-        # no pipeline split is specified
-        num_stages_per_rank = 2
+    schedule_class = get_schedule_class(schedule_str)
+    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
+    num_stages_per_rank = 1 if is_single_stage_schedule else 2
+
+    if layers_per_stage is not None:
+        total_stages = math.ceil(num_layers / layers_per_stage)
+        if total_stages % pp_dim != 0:
+            raise ValueError(
+                f"Number of stages ({total_stages}) must be divisible by the pipeline parallel dimension ({pp_dim}). "
+                f"Each rank should have the same number of stages. "
+            )
+        num_stages_per_rank = total_stages // pp_dim
+
+        if is_single_stage_schedule and num_stages_per_rank != 1:
+            raise ValueError(
+                f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single stage schedules."
+            )
+        elif not is_single_stage_schedule and num_stages_per_rank < 2:
+            raise ValueError(
+                f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi stage schedules."
+            )
     else:
-        raise ValueError(f"Unsupported pipeline schedule: {pipeline_parallel_schedule}")
-    total_stages = pp_dim * num_stages_per_rank
-    if total_stages > num_layers:
-        raise ValueError("Total stages cannot be greater than the number of layers")
-
-    base_interval = num_layers // total_stages
-    extra_layers = num_layers % total_stages
-
-    splits = []
-    current_layer = 0
-    for i in range(total_stages - 1):
-        if i == 0:
-            current_layer += base_interval
-        else:
-            # Middle stages get an extra layer if there are any remaining
-            if extra_layers > 0:
-                current_layer += base_interval + 1
-                extra_layers -= 1
-            else:
-                current_layer += base_interval
-        splits.append("layers." + str(current_layer))
+        total_stages = pp_dim * num_stages_per_rank
+        if total_stages > num_layers:
+            raise ValueError("Total stages cannot be greater than the number of layers")
+
+    # Calculate effective number of layers including input and output weights
+    effective_num_layers = num_layers + input_weight + output_weight
+    base_layers_per_stage = effective_num_layers // total_stages
+
+    splits = [""] * (total_stages - 1)
+    current_layer_index = 0
+
+    # First stage
+    layers_on_first_stage = max(0, base_layers_per_stage - input_weight)
+    current_layer_index += layers_on_first_stage
+    splits[0] = "layers." + str(current_layer_index)
+
+    # Last stage
+    layers_on_last_stage = max(0, base_layers_per_stage - output_weight)
+    splits[-1] = "layers." + str(num_layers - layers_on_last_stage)
+
+    # Middle stages
+    remaining_layers = num_layers - layers_on_first_stage - layers_on_last_stage - 1
+    middle_stages = len(splits) - 2
+    layers_per_middle_stage = remaining_layers // middle_stages
+    # split remainder evenly across middle stages
+    remainder = remaining_layers % middle_stages
+
+    for i in range(1, middle_stages + 1):
+        current_layer_index += layers_per_middle_stage
+        if remainder > 0:
+            current_layer_index += 1
+            remainder -= 1
+        splits[i] = "layers." + str(current_layer_index)
+
     logger.info(
         f"No 'pipeline_parallel_split_points' provided so the generated splits are: {splits} "
         "This may be sub-optimal as the number of layers per stage may be unbalanced."

torchtitan/models/llama3/pipeline_llama.py

Lines changed: 1 addition & 0 deletions
@@ -94,6 +94,7 @@ def pipeline_llama_manual_split(
 
     splits = parallelism_config.pipeline_parallel_split_points or generate_split_points(
         parallelism_config.pipeline_parallel_schedule,
+        parallelism_config.pipeline_parallel_layers_per_stage,
         parallel_dims.pp,
         model_config.n_layers,
     )
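For reference, each returned "layers.N" FQN names the module that begins the next stage. A tiny helper (hypothetical, not a torchtitan API) makes the mapping from split points to per-stage layer ranges explicit:

def stages_from_splits(splits: list[str], num_layers: int) -> list[range]:
    # "layers.N" means layer N is the first layer of a new stage.
    boundaries = [0] + [int(fqn.split(".")[1]) for fqn in splits] + [num_layers]
    return [range(lo, hi) for lo, hi in zip(boundaries, boundaries[1:])]

print(stages_from_splits(["layers.4"], num_layers=8))
# -> [range(0, 4), range(4, 8)]: layers 0-3 on stage 0, layers 4-7 on stage 1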
