
Commit 5dad8d3

stephenyan1231 authored and facebook-github-bot committed
add loss_backward_retain_graph to __init__() (#856)
Summary:
Pull Request resolved: #856

Expose **retain_graph** kwarg in **loss.backward()** by adding a new argument **loss_backward_retain_graph** to **AutoUnit.__init__()**

Reviewed By: JKSenthil

Differential Revision: D58901158

fbshipit-source-id: cdc051e15d1831cbc69b9cab7e92c0f6adb72c67
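For context, retain_graph controls whether autograd frees the graph once a backward pass finishes; a second backward over the same graph only works if the first call retained it. A minimal plain-PyTorch sketch of that behavior (illustration only, not part of this diff):

import torch

# Build a tiny graph and run backward over it twice.
x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum()

loss.backward(retain_graph=True)  # keep the autograd graph alive
loss.backward()  # succeeds and accumulates into x.grad; without
                 # retain_graph=True above, autograd would have freed the
                 # graph and this second call would raise a RuntimeError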
1 parent aa34ffb commit 5dad8d3


torchtnt/framework/auto_unit.py

Lines changed: 10 additions & 2 deletions
@@ -434,6 +434,10 @@ class AutoUnit(
         activation_checkpoint_params: params for enabling activation checkpointing
         training: if True, the optimizer and optionally LR scheduler will be created after the class is initialized.
         enable_compiled_autograd: if True, `compiled_autograd` will be used to compile the backward, this is an experimental flag.
+        loss_backward_retain_graph: If ``None`` or ``False``, the graph used to compute
+            the grads will be freed during loss backward pass. Note that in nearly all cases setting
+            this option to True is not needed and often can be worked around
+            in a much more efficient way.
 
     Note:
         Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
@@ -463,6 +467,7 @@ def __init__(
         activation_checkpoint_params: Optional[ActivationCheckpointParams] = None,
         training: bool = True,
         enable_compiled_autograd: bool = False,
+        loss_backward_retain_graph: Optional[bool] = None,
     ) -> None:
         super().__init__(
             module=module,
@@ -526,6 +531,7 @@ def __init__(
 
         self.enable_compiled_autograd = enable_compiled_autograd
         self.training = training
+        self.loss_backward_retain_graph = loss_backward_retain_graph
 
         self.optimizer: Optional[torch.optim.Optimizer] = None
         self.lr_scheduler: Optional[TLRScheduler] = None
@@ -620,12 +626,14 @@ def maybe_enable_compiled_autograd(
             with get_timing_context(
                 state, f"{self.__class__.__name__}.backward"
             ):
-                scaled_loss.backward()
+                scaled_loss.backward(
+                    retain_graph=self.loss_backward_retain_graph
+                )
         else:
             with get_timing_context(
                 state, f"{self.__class__.__name__}.backward"
             ):
-                loss.backward()
+                loss.backward(retain_graph=self.loss_backward_retain_graph)
 
         total_grad_norm = None
         if should_update_weights:
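
Below is a sketch of how a training recipe might opt in to the new argument. It assumes the usual AutoUnit abstract methods, compute_loss and configure_optimizers_and_lr_scheduler, and uses a hypothetical MyUnit subclass with a toy model; it is not taken from this commit.

from typing import Any, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyUnit(AutoUnit[Batch]):
    # Hypothetical subclass; method signatures assume the AutoUnit API as of this change.
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, None]:
        return torch.optim.SGD(module.parameters(), lr=0.01), None


# The new kwarg is forwarded to loss.backward(retain_graph=...) inside
# AutoUnit; leave it at the default None unless a second backward pass over
# the same graph is genuinely needed.
unit = MyUnit(
    module=torch.nn.Linear(16, 4),
    loss_backward_retain_graph=True,
)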

0 commit comments
