@@ -139,24 +139,23 @@ def build_pipeline_schedule(
     )
 
     looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
-    n_microbatches = job_config.parallelism.pipeline_parallel_microbatches
+    microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
+    batch_size = job_config.training.batch_size
+    # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
+    if batch_size % microbatch_size != 0:
+        raise ValueError(
147+ f"Batch size { job_config .training .batch_size } must be divisible by number of microbatches { n_microbatches } . "
148+ "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
149+ )
150+ n_microbatches = batch_size // microbatch_size
143151 # We expect that the number of local stages (`len(stages)`) is the same across all ranks
144152 num_total_stages = job_config .parallelism .pipeline_parallel_degree * len (stages )
145- if n_microbatches is None :
146- n_microbatches = num_total_stages
147- elif n_microbatches < num_total_stages :
153+ if n_microbatches < num_total_stages :
148154 logger .warning (
149155 f"Number of microbatches ({ n_microbatches } ) is less than the total number "
150156 f"of stages ({ num_total_stages } ) which may result in a bubble in the pipeline."
151157 )
152158
153- # validate that the batch size is divisible by the number of microbatches otherwise we'll hang or error during training
154- if job_config .training .batch_size % n_microbatches != 0 :
155- raise ValueError (
156- f"Batch size { job_config .training .batch_size } must be divisible by number of microbatches { n_microbatches } . "
157- "Update the config arguments for either batch_size or pipeline_parallel_microbatches."
158- )
159-
160159 schedule = schedule_class (
161160 stages if looped_schedule else stages [0 ],
162161 n_microbatches = n_microbatches ,
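For reference, a minimal standalone sketch of the new microbatch derivation: batch_size and microbatch_size carry the same meaning as the config fields in the diff above, but the helper function itself is hypothetical and not part of this change.

def derive_n_microbatches(batch_size: int, microbatch_size: int) -> int:
    # The global batch must split evenly into microbatches; an uneven split
    # would hang or error once the pipeline schedule starts running.
    if batch_size % microbatch_size != 0:
        raise ValueError(
            f"Batch size {batch_size} must be divisible by the microbatch size {microbatch_size}."
        )
    return batch_size // microbatch_size

# e.g. batch_size=8 with pipeline_parallel_microbatch_size=2 yields 4 microbatches,
# while batch_size=8 with microbatch_size=3 would raise the ValueError above.
assert derive_n_microbatches(8, 2) == 4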