Skip to content

Commit 863113c

Browse files
diego-urgell authored and facebook-github-bot committed
Allow metric-naive checkpoints for missing or malformed metric values (#939)
Summary: Pull Request resolved: #939 Reviewed By: richardwang-at-fb Differential Revision: D65452995 fbshipit-source-id: 4747dde369c5975e3d40901c4b64492795848aa9
1 parent b467a0e commit 863113c

File tree

2 files changed

+56
-31
lines changed

2 files changed

+56
-31
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -760,23 +760,32 @@ def test_keep_last_n_checkpoints_e2e(self) -> None:
760760
)
761761

762762
def test_best_checkpoint_attr_missing(self) -> None:
763-
bcs = BaseCheckpointSaver(
764-
"foo",
765-
save_every_n_epochs=1,
766-
best_checkpoint_config=BestCheckpointConfig(
767-
monitored_metric="train_loss",
768-
mode="min",
769-
),
770-
)
763+
with tempfile.TemporaryDirectory() as temp_dir:
764+
bcs = BaseCheckpointSaver(
765+
temp_dir,
766+
save_every_n_epochs=1,
767+
best_checkpoint_config=BestCheckpointConfig(
768+
monitored_metric="train_loss",
769+
mode="min",
770+
),
771+
)
771772

772-
state = get_dummy_train_state()
773-
my_val_unit = MyValLossUnit()
773+
state = get_dummy_train_state()
774+
my_val_unit = MyValLossUnit()
774775

775-
with self.assertRaisesRegex(
776-
RuntimeError,
777-
"Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
778-
):
779-
bcs.on_train_epoch_end(state, my_val_unit)
776+
error_container = []
777+
with patch(
778+
"torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
779+
side_effect=error_container.append,
780+
):
781+
bcs.on_train_epoch_end(state, my_val_unit)
782+
783+
self.assertIn(
784+
"Unit does not have attribute train_loss, unable to retrieve metric to checkpoint. Will not be included in checkpoint path, nor tracked for optimality.",
785+
error_container,
786+
)
787+
788+
self.assertTrue(os.path.exists(f"{temp_dir}/epoch_0_train_step_0"))
780789

781790
def test_best_checkpoint_no_top_k(self) -> None:
782791
"""
@@ -1008,15 +1017,20 @@ def test_get_tracked_metric_value(self) -> None:
10081017

10091018
# pyre-ignore
10101019
val_loss_unit.val_loss = "hola" # Test weird metric value
1011-
with self.assertRaisesRegex(
1012-
RuntimeError,
1013-
(
1014-
"Unable to convert monitored metric val_loss to a float. Please ensure the value "
1015-
"can be converted to float and is not a multi-element tensor value."
1016-
),
1020+
error_container = []
1021+
with patch(
1022+
"torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
1023+
side_effect=error_container.append,
10171024
):
10181025
val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
10191026

1027+
self.assertIn(
1028+
"Unable to convert monitored metric val_loss to a float: could not convert string to float: 'hola'. "
1029+
"Please ensure the value can be converted to float and is not a multi-element tensor value. Will not be "
1030+
"included in checkpoint path, nor tracked for optimality.",
1031+
error_container,
1032+
)
1033+
10201034
val_loss_unit.val_loss = float("nan") # Test nan metric value
10211035
error_container = []
10221036
with patch(
@@ -1053,12 +1067,19 @@ def test_get_tracked_metric_value(self) -> None:
10531067
dirpath="checkpoint",
10541068
best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),
10551069
)
1056-
with self.assertRaisesRegex(
1057-
RuntimeError,
1058-
"Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
1070+
error_container = []
1071+
with patch(
1072+
"torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
1073+
side_effect=error_container.append,
10591074
):
10601075
val_loss = train_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
10611076

1077+
self.assertIn(
1078+
"Unit does not have attribute train_loss, unable to retrieve metric to checkpoint. "
1079+
"Will not be included in checkpoint path, nor tracked for optimality.",
1080+
error_container,
1081+
)
1082+
10621083
ckpt_cb = BaseCheckpointSaver(
10631084
dirpath="checkpoint",
10641085
)

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,19 +285,23 @@ def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
285285

286286
monitored_metric_name = self._best_checkpoint_config.monitored_metric
287287
if not hasattr(unit, monitored_metric_name):
288-
raise RuntimeError(
289-
f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint."
288+
logger.error(
289+
f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint. "
290+
"Will not be included in checkpoint path, nor tracked for optimality."
290291
)
292+
return None
291293

292294
metric_value_f = None
293295
if (metric_value := getattr(unit, monitored_metric_name)) is not None:
294296
try:
295297
metric_value_f = float(metric_value)
296-
except ValueError as e:
297-
raise RuntimeError(
298-
f"Unable to convert monitored metric {monitored_metric_name} to a float. Please ensure the value "
299-
"can be converted to float and is not a multi-element tensor value."
300-
) from e
298+
except ValueError as exc:
299+
logger.error(
300+
f"Unable to convert monitored metric {monitored_metric_name} to a float: {exc}. Please ensure the value "
301+
"can be converted to float and is not a multi-element tensor value. Will not be included in checkpoint path, "
302+
"nor tracked for optimality."
303+
)
304+
return None
301305

302306
if metric_value_f and math.isnan(metric_value_f):
303307
logger.error(

0 commit comments

Comments (0)