Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9514b54

Browse files
committed
Remove mbs
1 parent 5951e39 commit 9514b54

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dist_run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
359359
output_args=example_outputs,
360360
group=pp_group,
361361
)
362+
363+
# Create schedule
362364
# Number of micro-batches for the schedule is 1, because each step() call we
363365
# only push 1 micro-batch into the pipeline. But we can continuously push
364366
# new micro-batches into the pipeline as they arrive, achieving same
365367
# pipelining effect.
366-
mbs = 1
367-
# create schedule
368-
prefiller = ScheduleGPipe(prefill_stage, mbs)
368+
prefiller = ScheduleGPipe(prefill_stage, 1)
369369

370370
prompt = [
371371
"What is a computer?",
@@ -456,7 +456,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
456456
group=pp_group,
457457
)
458458
# create schedule
459-
decorder = ScheduleGPipe(decode_stage, mbs)
459+
decorder = ScheduleGPipe(decode_stage, 1)
460460

461461
# Decoding
462462
with torch.no_grad(), CUDATrackTime() as timer:

0 commit comments

Comments
 (0)