
Commit 4c90a5f

galrotem authored and facebook-github-bot committed
autounit - extract lr scheduler step to a separate method (#778)
Summary:
Pull Request resolved: #778

Reviewed By: JKSenthil

Differential Revision: D55754227

fbshipit-source-id: 4d2ea236359001585debcd954b1a11606cafae42
1 parent ec33f1b commit 4c90a5f

File tree

1 file changed: +10, -3 lines

torchtnt/framework/auto_unit.py

Lines changed: 10 additions & 3 deletions
@@ -408,6 +408,8 @@ class AutoUnit(
     - ``on_eval_step_end``
     - ``on_predict_step_end``

+    The user can also override the LR step method, ``step_lr_scheduler``, in case they want to have custom logic.
+
     Then use with the :py:func:`~torchtnt.framework.train`, :py:func:`~torchtnt.framework.evaluate`, :py:func:`~torchtnt.framework.fit`, or
     :py:func:`~torchtnt.framework.predict` entry point as normal.

@@ -731,6 +733,12 @@ def on_predict_step_end(
         """
         pass

+    def step_lr_scheduler(self) -> None:
+        """
+        LR step method extracted to a method in case the user wants to override
+        """
+        none_throws(self.lr_scheduler).step()
+
     def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         """
         Updates weights of the module, handles clip gradient norm, etc.

@@ -865,12 +873,11 @@ def _update_lr_and_swa(self, state: State, number_of_steps_or_epochs: int) -> None:
                 none_throws(self.swa_scheduler).step()
         else:
             # optionally step lr scheduler if SWA not in use
-            lr_scheduler = self.lr_scheduler
-            if lr_scheduler:
+            if self.lr_scheduler:
                 with get_timing_context(
                     state, f"{self.__class__.__name__}.lr_scheduler_step"
                 ):
-                    lr_scheduler.step()
+                    self.step_lr_scheduler()


 def _validate_torch_compile_available() -> None:
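
For illustration, here is a minimal sketch of how a subclass could use the new hook by overriding step_lr_scheduler. The MyUnit class, the counter-based "step every other call" policy, and the loss/optimizer choices are hypothetical and not part of this commit; only step_lr_scheduler, compute_loss, configure_optimizers_and_lr_scheduler, and the keyword-only module argument are assumed from the existing AutoUnit API.

# Hypothetical sketch (not part of this diff): overriding the new
# step_lr_scheduler hook in an AutoUnit subclass to customize when the
# LR scheduler advances, while reusing the default behavior via super().
from typing import Any, Tuple

import torch
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.state import State

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyUnit(AutoUnit[Batch]):
    def __init__(self, *, module: torch.nn.Module, **kwargs: Any) -> None:
        super().__init__(module=module, **kwargs)
        self._lr_step_calls = 0  # hypothetical counter driving the custom policy

    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs

    def configure_optimizers_and_lr_scheduler(
        self, module: torch.nn.Module
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.StepLR]:
        optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return optimizer, lr_scheduler

    def step_lr_scheduler(self) -> None:
        # Custom logic: only advance the scheduler on every other call;
        # super().step_lr_scheduler() runs the default none_throws(self.lr_scheduler).step().
        self._lr_step_calls += 1
        if self._lr_step_calls % 2 == 0:
            super().step_lr_scheduler()

Because _update_lr_and_swa now calls self.step_lr_scheduler() instead of stepping the scheduler inline, an override like this changes when the scheduler advances without touching the timing-context or SWA branching logic.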
