
Commit 2e6cd59

JKSenthil authored and facebook-github-bot committed
Move zero grads logic at the beginning of train step (#974)
Summary:
Pull Request resolved: #974

Reviewed By: diego-urgell

Differential Revision: D69117224

fbshipit-source-id: fa3d6177c4aa4e5a19a3c803ca12d8898ce1b23e
1 parent 7390c77 commit 2e6cd59

1 file changed: +24 -9 lines changed

torchtnt/framework/auto_unit.py

Lines changed: 24 additions & 9 deletions
@@ -471,6 +471,7 @@ class AutoUnit(
             this option to True is not needed and often can be worked around
             in a much more efficient way.
         enable_prefetch: if True, the data will be prefetched to the device before the next batch is loaded
+        zero_grad_at_train_step_start: if True, the optimizer's gradients will be zeroed at the start of each train step, rather than at the end. Useful if you want to inspect/log the gradients via custom callback.
 
     Note:
         Certain strategies, like :class:`~torchtnt.utils.prepare_module.FSDPStrategy` also support mixed precision as an argument, so can be configured through that class as well.
@@ -506,6 +507,7 @@ def __init__(
         enable_compiled_autograd: bool = False,
         loss_backward_retain_graph: Optional[bool] = None,
         enable_prefetch: bool = True,
+        zero_grad_at_train_step_start: bool = False,
     ) -> None:
         super().__init__(
             module=module,
@@ -576,6 +578,10 @@ def __init__(
         self.lr_scheduler: Optional[TLRScheduler] = None
         self.swa_scheduler: Optional[SWALR] = None
 
+        self.zero_grad_at_train_step_start: bool = zero_grad_at_train_step_start
+        # keep track of when to zero grad at train step start
+        self._weight_updated_in_prev_step = False
+
     def __setattr__(self, name: str, value: object) -> None:
         if isinstance(value, torch.nn.Module):
             self._validate_module_attr(name, value)
@@ -653,6 +659,11 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
             self.train_progress.num_steps_completed_in_epoch + 1
         ) % self.gradient_accumulation_steps == 0 or self._is_last_batch
 
+        # zero the gradients if previous step updated weights
+        if self._weight_updated_in_prev_step and self.zero_grad_at_train_step_start:
+            self.zero_grad(state)
+            self._weight_updated_in_prev_step = False
+
         # for pyre, assign to local variable
         module = self.module
 
@@ -829,21 +840,24 @@ def step_lr_scheduler(self) -> None:
         """
         none_throws(self.lr_scheduler).step()
 
-    def zero_grad(self) -> None:
+    def zero_grad(self, state: State) -> None:
         """
-        Zeroes the gradients of the module's parameters. Override this if you need to log the gradients before zeroing them.
+        Zeroes the gradients of the module's parameters. You can override this if you want to log the gradients before zeroing them.
 
         Example of overriding:
         class CustomAutoUnit(MyAutoUnit):
             ...
 
-            def zero_grad(self):
+            def zero_grad(self, state):
                 # log before zeroing gradients
                 super().zero_grad()
         """
 
         optimizer = none_throws(self.optimizer)
-        optimizer.zero_grad(set_to_none=True)
+        with get_timing_context(
+            state, f"{self.__class__.__name__}.optimizer_zero_grad"
+        ):
+            optimizer.zero_grad(set_to_none=True)
 
     def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         """
@@ -904,11 +918,12 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         else:
             optimizer.step()
 
-        # sets gradients to zero
-        with get_timing_context(
-            state, f"{self.__class__.__name__}.optimizer_zero_grad"
-        ):
-            self.zero_grad()
+        if self.zero_grad_at_train_step_start:
+            # mark that weights were updated in this step
+            # so in next step we know to zero the gradients
+            self._weight_updated_in_prev_step = True
+        else:
+            self.zero_grad(state)
 
         if self.step_lr_interval == "step":
             self._update_lr_and_swa(state, self.train_progress.num_steps_completed)
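
For context on how the new option is meant to be used, here is a minimal sketch that enables zero_grad_at_train_step_start and attaches a callback that reads gradients after each train step. This is illustrative code, not part of the commit: the GradNormLogger and MyAutoUnit classes are made up for the example, and the Callback.on_train_step_end hook, the train() entry point, and AutoUnit's compute_loss / configure_optimizers_and_lr_scheduler methods are assumed from torchtnt's public API (check the names against your installed version).

# Illustrative sketch only (not part of this commit). Assumes torchtnt's
# Callback.on_train_step_end hook, the train() entry point, and AutoUnit's
# compute_loss / configure_optimizers_and_lr_scheduler abstract methods.
from typing import Any, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.callback import Callback
from torchtnt.framework.state import State
from torchtnt.framework.train import train

Batch = Tuple[torch.Tensor, torch.Tensor]


class GradNormLogger(Callback):
    def on_train_step_end(self, state: State, unit: Any) -> None:
        # With zero_grad_at_train_step_start=True, the gradients from this step
        # are still populated here; they are only cleared at the start of the
        # next train_step.
        grads = [p.grad for p in unit.module.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.linalg.vector_norm(
                torch.stack([torch.linalg.vector_norm(g) for g in grads])
            )
            print(
                f"step {unit.train_progress.num_steps_completed}: "
                f"grad norm {total_norm.item():.4f}"
            )


class MyAutoUnit(AutoUnit[Batch]):
    # Hypothetical minimal AutoUnit subclass for the example.
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        return torch.nn.functional.cross_entropy(outputs, targets), outputs

    def configure_optimizers_and_lr_scheduler(self, module: torch.nn.Module):
        return torch.optim.SGD(module.parameters(), lr=0.01), None


dataloader = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(10)]
unit = MyAutoUnit(
    module=torch.nn.Linear(8, 2),
    zero_grad_at_train_step_start=True,  # new flag introduced by this commit
)
train(unit, dataloader, max_epochs=1, callbacks=[GradNormLogger()])

Because zeroing is deferred to the start of the next train_step, the gradients produced by a step remain available to callbacks and to anything else that runs between the optimizer step and the next train step.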

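The override example in the zero_grad docstring above is abbreviated; a slightly fuller sketch might look like the following. CustomAutoUnit builds on the hypothetical MyAutoUnit from the previous sketch, and since zero_grad now takes state, the override forwards it to super().zero_grad(state) so the base class can wrap optimizer.zero_grad in get_timing_context.

# Sketch of overriding the new zero_grad(state) hook to log gradients before
# they are cleared; builds on the hypothetical MyAutoUnit defined above.
from torchtnt.framework.state import State


class CustomAutoUnit(MyAutoUnit):
    def zero_grad(self, state: State) -> None:
        # Log per-parameter gradient norms before they are zeroed.
        for name, param in self.module.named_parameters():
            if param.grad is not None:
                print(f"{name}: grad norm {param.grad.norm().item():.4f}")
        # The base implementation now expects `state`, which it uses to time
        # the optimizer.zero_grad call.
        super().zero_grad(state)

With zero_grad_at_train_step_start=True, this override runs at the start of the train step that follows a weight update rather than immediately after optimizer.step().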