Skip to content

Commit 8ee0aa9

Browse files
diego-urgell authored and facebook-github-bot committed
Don't include NaN metric values in ckpt paths (#896)
Summary: Pull Request resolved: #896
Reviewed By: JKSenthil
Differential Revision: D62469085
fbshipit-source-id: 746ba7d16390e2cc7fa513961f500317c73bcf06
1 parent 33b98f4 commit 8ee0aa9

File tree

4 files changed

+61
-0
lines changed

4 files changed

+61
-0
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -919,6 +919,23 @@ def test_get_tracked_metric_value(self) -> None:
919919
):
920920
val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
921921

922+
val_loss_unit.val_loss = float("nan") # Test nan metric value
923+
error_container = []
924+
with patch(
925+
"torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
926+
side_effect=error_container.append,
927+
):
928+
val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
929+
930+
self.assertEqual(
931+
[
932+
"Monitored metric 'val_loss' is NaN. Will not be included in checkpoint path, nor tracked for optimality."
933+
],
934+
error_container,
935+
)
936+
self.assertIsNone(val_loss)
937+
938+
# test with mismatched monitored metric
922939
train_loss_ckpt_cb = BaseCheckpointSaver(
923940
dirpath="checkpoint",
924941
best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),

tests/utils/test_checkpoint.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,37 @@
4747

4848

4949
class CheckpointPathTest(unittest.TestCase):
50+
51+
def test_create_checkpoint_path(self) -> None:
52+
# phase-naive and metric-naive
53+
ckpt = CheckpointPath("foo", epoch=0, step=1)
54+
self.assertEqual(ckpt.path, "foo/epoch_0_step_1")
55+
56+
# phase-aware and metric-naive
57+
ckpt = CheckpointPath("foo", epoch=0, step={Phase.TRAIN: 1})
58+
self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1")
59+
60+
# phase-aware and metric-aware
61+
ckpt = CheckpointPath(
62+
"foo",
63+
epoch=0,
64+
step={Phase.TRAIN: 1, Phase.EVALUATE: 1},
65+
metric_data=MetricData("foo", 1.0),
66+
)
67+
self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_foo=1.0")
68+
69+
# nan metric value
70+
with self.assertRaisesRegex(
71+
ValueError,
72+
"Value of monitored metric 'foo' can't be NaN in CheckpointPath.",
73+
):
74+
CheckpointPath(
75+
"foo",
76+
epoch=0,
77+
step={Phase.TRAIN: 1, Phase.EVALUATE: 1},
78+
metric_data=MetricData("foo", float("nan")),
79+
)
80+
5081
def test_from_str(self) -> None:
5182
# invalid paths
5283
malformed_paths = [

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
import abc
1010
import logging
11+
import math
1112
from datetime import timedelta
1213
from typing import Any, cast, Iterable, List, Literal, Optional, Union
1314

@@ -256,6 +257,12 @@ def _get_tracked_metric_value(
256257
"can be converted to float and is not a multi-element tensor value."
257258
) from e
258259

260+
if metric_value_f and math.isnan(metric_value_f):
261+
logger.error(
262+
f"Monitored metric '{monitored_metric_name}' is NaN. Will not be included in checkpoint path, nor tracked for optimality."
263+
)
264+
return None
265+
259266
return metric_value_f
260267

261268
def on_train_start(self, state: State, unit: TTrainUnit) -> None:

torchtnt/utils/checkpoint.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88
import bisect
99
import logging
10+
import math
1011
import os
1112
import re
1213
from dataclasses import dataclass
@@ -105,6 +106,11 @@ def __init__(
105106
step if isinstance(step, dict) else {Phase.NONE: step}
106107
)
107108

109+
if metric_data and math.isnan(metric_data.value):
110+
raise ValueError(
111+
f"Value of monitored metric '{metric_data.name}' can't be NaN in CheckpointPath."
112+
)
113+
108114
@classmethod
109115
def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
110116
"""

0 commit comments

Comments
 (0)