Commit e5ef99a

Update PP to release memory earlier (#1922)
Uses the API added in pytorch/pytorch#165822, since we do not return any output from PP step(). This allows us to release the memory earlier.
1 parent b1644a4 commit e5ef99a

1 file changed: +2 -0 lines changed

torchtitan/train.py

Lines changed: 2 additions & 0 deletions
@@ -460,12 +460,14 @@ def forward_backward_step(
                         **extra_kwargs,
                         target=targets,
                         losses=losses,
+                        return_outputs=False,
                     )
                 else:
                     self.pp_schedule.step(
                         **extra_kwargs,
                         target=targets,
                         losses=losses,
+                        return_outputs=False,
                     )
 
             # accumulate losses across pipeline microbatches
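
To illustrate why skipping the returned outputs can free memory sooner, here is a minimal sketch in plain Python. It is not the PyTorch implementation: ToySchedule, Activation, and live_outputs are hypothetical names invented for this example, and only the return_outputs and losses keyword names come from the diff above. The assumption is that a schedule which does not have to hand outputs back to the caller can drop each microbatch output as soon as its loss has been recorded.

```python
import gc
import weakref


class Activation:
    """Hypothetical stand-in for one microbatch's output tensor."""

    def __init__(self, microbatch_id):
        self.microbatch_id = microbatch_id
        self.loss = float(microbatch_id)  # pretend per-microbatch loss


class ToySchedule:
    """Hypothetical mock of a pipeline schedule (not the PyTorch class)."""

    def __init__(self, n_microbatches):
        self.n_microbatches = n_microbatches
        # Weak references only observe lifetimes; they keep nothing alive.
        self.live_refs = []

    def step(self, losses=None, return_outputs=True):
        # Only build the output list when the caller asked for it.
        kept = [] if return_outputs else None
        for i in range(self.n_microbatches):
            out = Activation(i)
            self.live_refs.append(weakref.ref(out))
            if losses is not None:
                losses.append(out.loss)
            if kept is not None:
                kept.append(out)
            # Drop the local reference; without `kept`, nothing holds `out`.
            del out
        return kept

    def live_outputs(self):
        """Count the outputs that are still reachable."""
        gc.collect()
        return sum(ref() is not None for ref in self.live_refs)


if __name__ == "__main__":
    for flag in (True, False):
        schedule = ToySchedule(n_microbatches=4)
        losses = []
        outputs = schedule.step(losses=losses, return_outputs=flag)
        print(flag, schedule.live_outputs(), losses)
```

In this toy model the losses list is filled either way, which matches how the training loop above only consumes losses after step(); the difference is whether the per-microbatch outputs stay reachable after the call returns.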
