Skip to content

Commit 1918819

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Warn when early stop checker returns True (#994)
Summary: Pull Request resolved: #994 Reviewed By: JKSenthil Differential Revision: D73955323 fbshipit-source-id: fb967c02cd684b9989cf16cd6f608518e03ad502
1 parent 59bfa39 commit 1918819

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

torchtnt/framework/callbacks/early_stopping.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import logging
910
from typing import Literal
1011

1112
from torchtnt.framework.callback import Callback
@@ -14,6 +15,8 @@
1415
from torchtnt.utils.distributed import get_global_rank, sync_bool
1516
from torchtnt.utils.early_stop_checker import EarlyStopChecker
1617

18+
logger: logging.Logger = logging.getLogger(__name__)
19+
1720

1821
class EarlyStopping(Callback):
1922
"""
@@ -102,4 +105,5 @@ def _maybe_stop(self, state: State, unit: AppStateMixin) -> None:
102105

103106
should_stop = sync_bool(should_stop, coherence_mode="rank_zero")
104107
if should_stop:
108+
logger.warning("Stopping training early due to early stopping criteria.")
105109
state.stop()

torchtnt/utils/early_stop_checker.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,15 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
189189

190190
# Check finite
191191
if self.check_finite and not torch.isfinite(val):
192-
_log.debug(
192+
_log.warning(
193193
f"Metric is not finite: {val}."
194194
f" Previous best value was {self._best_value}."
195195
)
196196
return True
197197

198198
# Check if reached stopping threshold
199199
if stopping_threshold is not None and self._mode_func(val, stopping_threshold):
200-
_log.debug(
200+
_log.warning(
201201
"Stopping threshold reached:"
202202
f" {val} {self._mode_char} {stopping_threshold}."
203203
)
@@ -207,7 +207,7 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
207207
if divergence_threshold is not None and self._mode_func(
208208
-val, -divergence_threshold
209209
):
210-
_log.debug(
210+
_log.warning(
211211
"Divergence threshold reached:"
212212
f" {val} {self._mode_char} {divergence_threshold}."
213213
)
@@ -222,6 +222,8 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
222222
message = self._improvement_message(val)
223223
self._best_value = val
224224
self._patience_count = 0
225+
_log.debug(message)
226+
225227
else:
226228
# Not improving
227229
self._patience_count += 1
@@ -241,7 +243,8 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
241243
f" {self.patience - self._patience_count} checks of patience remaining."
242244
)
243245

244-
_log.debug(message)
246+
_log.warning(message)
247+
245248
return should_stop
246249

247250
@property

0 commit comments

Comments
 (0)