Skip to content

Commit f9838b8

Browse files
rehm-gfacebook-github-bot
authored andcommitted
Add eval step/epoch early stopping to TNT callback (#920)
Summary: Pull Request resolved: #920 Add eval phase testing to the early stopping TNT callback Reviewed By: JKSenthil Differential Revision: D64004301 fbshipit-source-id: 81c23bde273cca140647693960dfb82224c95bef
1 parent 7bfdee4 commit f9838b8

File tree

2 files changed

+81
-5
lines changed

2 files changed

+81
-5
lines changed

tests/framework/callbacks/test_early_stopping.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
from typing import cast, Literal
1111
from unittest.mock import MagicMock, patch
1212

13-
from torchtnt.framework._test_utils import Batch, get_dummy_train_state
13+
from torchtnt.framework._test_utils import (
14+
Batch,
15+
get_dummy_eval_state,
16+
get_dummy_train_state,
17+
)
1418

1519
from torchtnt.framework.callbacks.early_stopping import EarlyStopping
1620
from torchtnt.framework.state import State
17-
from torchtnt.framework.unit import TrainUnit
21+
from torchtnt.framework.unit import EvalUnit, TrainUnit
1822

1923
from torchtnt.utils.early_stop_checker import EarlyStopChecker
2024

@@ -131,6 +135,48 @@ def test_interval_freq(self, _maybe_stop: MagicMock) -> None:
131135
esc.on_train_step_end(state, unit)
132136
_maybe_stop.assert_called_once()
133137

138+
@patch("torchtnt.framework.callbacks.early_stopping.EarlyStopping._maybe_stop")
139+
def test_phase(self, _maybe_stop: MagicMock) -> None:
140+
early_stop_checker = EarlyStopChecker(
141+
mode="min",
142+
patience=2,
143+
min_delta=0.0,
144+
)
145+
esc = EarlyStopping(
146+
monitored_attr="eval_loss",
147+
early_stop_checker=early_stop_checker,
148+
interval="epoch",
149+
interval_freq=2,
150+
phase="eval",
151+
)
152+
153+
state = get_dummy_eval_state()
154+
unit = MyEvalLossUnit()
155+
156+
unit.eval_progress.increment_epoch()
157+
esc.on_eval_epoch_end(state, unit)
158+
_maybe_stop.assert_not_called()
159+
unit.eval_progress.increment_epoch()
160+
esc.on_eval_epoch_end(state, unit)
161+
_maybe_stop.assert_called_once()
162+
163+
_maybe_stop.reset_mock()
164+
165+
esc = EarlyStopping(
166+
monitored_attr="eval_loss",
167+
early_stop_checker=early_stop_checker,
168+
interval="step",
169+
interval_freq=2,
170+
phase="eval",
171+
)
172+
173+
unit.eval_progress.increment_step()
174+
esc.on_eval_step_end(state, unit)
175+
_maybe_stop.assert_not_called()
176+
unit.eval_progress.increment_step()
177+
esc.on_eval_step_end(state, unit)
178+
_maybe_stop.assert_called_once()
179+
134180

135181
class MyTrainLossUnit(TrainUnit[Batch]):
136182
def __init__(self) -> None:
@@ -139,3 +185,12 @@ def __init__(self) -> None:
139185

140186
def train_step(self, state: State, data: Batch) -> None:
141187
return None
188+
189+
190+
class MyEvalLossUnit(EvalUnit[Batch]):
191+
def __init__(self) -> None:
192+
super().__init__()
193+
self.eval_loss = 0.01
194+
195+
def eval_step(self, state: State, data: Batch) -> None:
196+
return None

torchtnt/framework/callbacks/early_stopping.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from torchtnt.framework.callback import Callback
1212
from torchtnt.framework.state import State
13-
from torchtnt.framework.unit import AppStateMixin, TTrainUnit
13+
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainUnit
1414
from torchtnt.utils.distributed import get_global_rank, sync_bool
1515
from torchtnt.utils.early_stop_checker import EarlyStopChecker
1616

@@ -23,6 +23,7 @@ class EarlyStopping(Callback):
2323
monitored_attr: The attribute to monitor on the unit. Must be a float or tensor attribute on the unit.
2424
early_stop_checker: a :class:`~torchtnt.utils.early_stop_checker.EarlyStopChecker` to use for checking whether to stop early.
2525
interval: The interval to check the monitored attribute. Must be one of "step" or "epoch".
26+
phase: The phase to check the monitored attribute. Must be one of "train" or "eval".
2627
2728
Note:
2829
If doing distributed training, this callback checks the metric value only on rank 0
@@ -33,29 +34,49 @@ def __init__(
3334
monitored_attr: str,
3435
early_stop_checker: EarlyStopChecker,
3536
interval: Literal["step", "epoch"] = "epoch",
37+
phase: Literal["train", "eval"] = "train",
3638
interval_freq: int = 1,
3739
) -> None:
3840
self._monitored_attr = monitored_attr
3941
self._esc = early_stop_checker
4042
self._interval = interval
4143
self._interval_freq = interval_freq
44+
self._phase = phase
4245

4346
self._rank: int = get_global_rank()
4447

4548
def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
4649
if (
47-
self._interval == "step"
50+
self._phase == "train"
51+
and self._interval == "step"
4852
and unit.train_progress.num_steps_completed % self._interval_freq == 0
4953
):
5054
self._maybe_stop(state, unit)
5155

5256
def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
5357
if (
54-
self._interval == "epoch"
58+
self._phase == "train"
59+
and self._interval == "epoch"
5560
and unit.train_progress.num_epochs_completed % self._interval_freq == 0
5661
):
5762
self._maybe_stop(state, unit)
5863

64+
def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
65+
if (
66+
self._phase == "eval"
67+
and self._interval == "step"
68+
and unit.eval_progress.num_steps_completed % self._interval_freq == 0
69+
):
70+
self._maybe_stop(state, unit)
71+
72+
def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
73+
if (
74+
self._phase == "eval"
75+
and self._interval == "epoch"
76+
and unit.eval_progress.num_epochs_completed % self._interval_freq == 0
77+
):
78+
self._maybe_stop(state, unit)
79+
5980
def _maybe_stop(self, state: State, unit: AppStateMixin) -> None:
6081
"""
6182
Checks whether to stop early based on the monitored attribute.

0 commit comments

Comments
 (0)