|
14 | 14 | from typing_extensions import final, Literal
|
15 | 15 |
|
16 | 16 | _log: logging.Logger = logging.getLogger(__name__)
|
| 17 | +_log.setLevel(logging.DEBUG) |
17 | 18 |
|
18 | 19 |
|
19 | 20 | @final
|
@@ -179,11 +180,13 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
|
179 | 180 | divergence_threshold = divergence_threshold.to(val.device)
|
180 | 181 | improvement_threshold = self.min_delta
|
181 | 182 | if self._threshold_mode == "rel":
|
182 |
| - base_val = self._best_value if torch.isfinite(self._best_value) else 0.0 |
| 183 | + base_val = ( |
| 184 | + self._best_value.to(val.device) |
| 185 | + if torch.isfinite(self._best_value) |
| 186 | + else 0.0 |
| 187 | + ) |
183 | 188 | improvement_threshold = self.min_delta.to(val.device) * base_val
|
184 | 189 |
|
185 |
| - improvement_threshold = improvement_threshold.to(val.device) |
186 |
| - |
187 | 190 | # Check finite
|
188 | 191 | if self.check_finite and not torch.isfinite(val):
|
189 | 192 | _log.debug(
|
@@ -212,7 +215,7 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
|
212 | 215 |
|
213 | 216 | # Check if improvement is happening
|
214 | 217 | if self._mode_func(
|
215 |
| - val - improvement_threshold, self._best_value.to(val.device) |
| 218 | + val - improvement_threshold.to(val.device), self._best_value.to(val.device) |
216 | 219 | ):
|
217 | 220 | # Still improving
|
218 | 221 | should_stop = False
|
@@ -259,9 +262,12 @@ def _improvement_message(self, val: torch.Tensor) -> str:
|
259 | 262 | """Formats a log message that informs the user about an improvement in the monitored score."""
|
260 | 263 | if torch.isfinite(self._best_value):
|
261 | 264 | improvement = (
|
262 |
| - torch.abs(self._best_value - val) |
| 265 | + torch.abs(self._best_value.to(val.device) - val) |
263 | 266 | if self.threshold_mode == "abs"
|
264 |
| - else torch.abs((self._best_value - val) / (1.0 * self._best_value)) |
| 267 | + else torch.abs( |
| 268 | + (self._best_value.to(val.device) - val) |
| 269 | + / (1.0 * self._best_value.to(val.device)) |
| 270 | + ) |
265 | 271 | )
|
266 | 272 | msg = (
|
267 | 273 | f"Metric improved by {self.threshold_mode} {improvement} >="
|
|
0 commit comments