Commit 5e2033c

[PP] Add microbatch_size and remove num_microbatches config (#1073)
landing #1050
1 parent 8b737fe commit 5e2033c

File tree

3 files changed: +16 -18 lines

tests/integration_tests.py
torchtitan/config_manager.py
torchtitan/distributed/pipeline.py

tests/integration_tests.py

Lines changed: 0 additions & 2 deletions
@@ -156,7 +156,6 @@ def build_test_list():
             [
                 "--parallelism.pipeline_parallel_degree 2",
                 "--parallelism.pipeline_parallel_schedule ZBVZeroBubble",
-                "--parallelism.pipeline_parallel_microbatches 8",
             ],
         ],
         "PP zero bubble test (v shaped)",
@@ -280,7 +279,6 @@ def build_test_list():
                 "--parallelism.pipeline_parallel_degree 2",
                 "--parallelism.pipeline_parallel_schedule PipelineScheduleMulti",
                 "--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv",
-                "--parallelism.pipeline_parallel_microbatches 8",
             ],
         ],
         "PP with custom pipeline schedule loaded from CSV file",

torchtitan/config_manager.py

Lines changed: 6 additions & 5 deletions
@@ -520,15 +520,16 @@ def __init__(self):
             """,
         )
         self.parser.add_argument(
-            "--parallelism.pipeline_parallel_microbatches",
+            "--parallelism.pipeline_parallel_microbatch_size",
             type=int,
-            default=None,
+            default=1,
             help="""
-                How many microbatches to split the global training batch into when using pipeline parallelism.
+                The size of each pipeline parallel microbatch (default 1).
 
-                The global training batch size must be evenly divisible by the number of microbatches.
+                This value is used to compute the total number of microbatches by dividing batch_size
+                by pipeline_parallel_microbatch_size.
 
-                The default value will be the number of pipeline stages, if unspecified.
+                The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
             """,
         )
         self.parser.add_argument(
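A small sketch of the arithmetic the new help text describes; the concrete numbers here are illustrative, not taken from the commit:

    batch_size = 8        # training.batch_size
    microbatch_size = 2   # parallelism.pipeline_parallel_microbatch_size
    assert batch_size % microbatch_size == 0         # must divide evenly, else config error
    n_microbatches = batch_size // microbatch_size   # -> 4 microbatches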

torchtitan/distributed/pipeline.py

Lines changed: 10 additions & 11 deletions
@@ -139,24 +139,23 @@ def build_pipeline_schedule(
     )
 
     looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
-    n_microbatches = job_config.parallelism.pipeline_parallel_microbatches
+    microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
+    batch_size = job_config.training.batch_size
+    # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
+    if batch_size % microbatch_size != 0:
+        raise ValueError(
+            f"Batch size {job_config.training.batch_size} must be divisible by microbatch_size {microbatch_size}. "
+            "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
+        )
+    n_microbatches = batch_size // microbatch_size
     # We expect that the number of local stages (`len(stages)`) is the same across all ranks
     num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
-    if n_microbatches is None:
-        n_microbatches = num_total_stages
-    elif n_microbatches < num_total_stages:
+    if n_microbatches < num_total_stages:
         logger.warning(
             f"Number of microbatches ({n_microbatches}) is less than the total number "
             f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
         )
 
-    # validate that the batch size is divisible by the number of microbatches otherwise we'll hang or error during training
-    if job_config.training.batch_size % n_microbatches != 0:
-        raise ValueError(
-            f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
-            "Update the config arguments for either batch_size or pipeline_parallel_microbatches."
-        )
-
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
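Read end to end, the new logic validates divisibility first, derives the microbatch count, and only then compares it against the stage count. A self-contained sketch of that flow; derive_n_microbatches, pp_degree, and num_local_stages are illustrative stand-ins for the job_config fields above, not torchtitan APIs:

    def derive_n_microbatches(
        batch_size: int, microbatch_size: int, pp_degree: int, num_local_stages: int
    ) -> int:
        # an uneven split would hang or error during training, so fail fast
        if batch_size % microbatch_size != 0:
            raise ValueError(
                f"Batch size {batch_size} must be divisible by "
                f"microbatch_size {microbatch_size}."
            )
        n_microbatches = batch_size // microbatch_size
        # fewer microbatches than total stages leaves some stages idle (a bubble)
        num_total_stages = pp_degree * num_local_stages
        if n_microbatches < num_total_stages:
            print(f"warning: {n_microbatches} microbatches < {num_total_stages} stages")
        return n_microbatches

    # example: global batch 8 split into microbatches of 2 on a 2-rank PP mesh
    # with 2 local stages per rank -> 4 microbatches == 4 total stages, no warning
    assert derive_n_microbatches(8, 2, 2, 2) == 4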
