Skip to content

Commit 6e6824c

Browse files
Pavel Levin authored and facebook-github-bot committed
Early stopper inconsistent devices fix (#949)
Summary: Pull Request resolved: #949. `val` and `self._best_value` can be on inconsistent devices in multi-GPU training, which makes training fail at the early stopper checks. Reviewed By: JKSenthil. Differential Revision: D66160768. fbshipit-source-id: 7f80900343bbdb80b118052452156a4fd5b67b73
1 parent 1fe0a5d commit 6e6824c

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

torchtnt/utils/early_stop_checker.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing_extensions import final, Literal
1515

1616
_log: logging.Logger = logging.getLogger(__name__)
17+
_log.setLevel(logging.DEBUG)
1718

1819

1920
@final
@@ -179,11 +180,13 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
179180
divergence_threshold = divergence_threshold.to(val.device)
180181
improvement_threshold = self.min_delta
181182
if self._threshold_mode == "rel":
182-
base_val = self._best_value if torch.isfinite(self._best_value) else 0.0
183+
base_val = (
184+
self._best_value.to(val.device)
185+
if torch.isfinite(self._best_value)
186+
else 0.0
187+
)
183188
improvement_threshold = self.min_delta.to(val.device) * base_val
184189

185-
improvement_threshold = improvement_threshold.to(val.device)
186-
187190
# Check finite
188191
if self.check_finite and not torch.isfinite(val):
189192
_log.debug(
@@ -212,7 +215,7 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
212215

213216
# Check if improvement is happening
214217
if self._mode_func(
215-
val - improvement_threshold, self._best_value.to(val.device)
218+
val - improvement_threshold.to(val.device), self._best_value.to(val.device)
216219
):
217220
# Still improving
218221
should_stop = False
@@ -259,9 +262,12 @@ def _improvement_message(self, val: torch.Tensor) -> str:
259262
"""Formats a log message that informs the user about an improvement in the monitored score."""
260263
if torch.isfinite(self._best_value):
261264
improvement = (
262-
torch.abs(self._best_value - val)
265+
torch.abs(self._best_value.to(val.device) - val)
263266
if self.threshold_mode == "abs"
264-
else torch.abs((self._best_value - val) / (1.0 * self._best_value))
267+
else torch.abs(
268+
(self._best_value.to(val.device) - val)
269+
/ (1.0 * self._best_value.to(val.device))
270+
)
265271
)
266272
msg = (
267273
f"Metric improved by {self.threshold_mode} {improvement} >="

0 commit comments

Comments
 (0)