@@ -760,23 +760,32 @@ def test_keep_last_n_checkpoints_e2e(self) -> None:
         )

     def test_best_checkpoint_attr_missing(self) -> None:
-        bcs = BaseCheckpointSaver(
-            "foo",
-            save_every_n_epochs=1,
-            best_checkpoint_config=BestCheckpointConfig(
-                monitored_metric="train_loss",
-                mode="min",
-            ),
-        )
+        with tempfile.TemporaryDirectory() as temp_dir:
+            bcs = BaseCheckpointSaver(
+                temp_dir,
+                save_every_n_epochs=1,
+                best_checkpoint_config=BestCheckpointConfig(
+                    monitored_metric="train_loss",
+                    mode="min",
+                ),
+            )

-        state = get_dummy_train_state()
-        my_val_unit = MyValLossUnit()
+            state = get_dummy_train_state()
+            my_val_unit = MyValLossUnit()

-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
-        ):
-            bcs.on_train_epoch_end(state, my_val_unit)
+            error_container = []
+            with patch(
+                "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
+                side_effect=error_container.append,
+            ):
+                bcs.on_train_epoch_end(state, my_val_unit)
+
+            self.assertIn(
+                "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint. Will not be included in checkpoint path, nor tracked for optimality.",
+                error_container,
+            )
+
+            self.assertTrue(os.path.exists(f"{temp_dir}/epoch_0_train_step_0"))

     def test_best_checkpoint_no_top_k(self) -> None:
         """
@@ -1008,15 +1017,20 @@ def test_get_tracked_metric_value(self) -> None:

         # pyre-ignore
         val_loss_unit.val_loss = "hola"  # Test weird metric value
-        with self.assertRaisesRegex(
-            RuntimeError,
-            (
-                "Unable to convert monitored metric val_loss to a float. Please ensure the value "
-                "can be converted to float and is not a multi-element tensor value."
-            ),
+        error_container = []
+        with patch(
+            "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
+            side_effect=error_container.append,
         ):
             val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)

+        self.assertIn(
+            "Unable to convert monitored metric val_loss to a float: could not convert string to float: 'hola'. "
+            "Please ensure the value can be converted to float and is not a multi-element tensor value. Will not be "
+            "included in checkpoint path, nor tracked for optimality.",
+            error_container,
+        )
+
         val_loss_unit.val_loss = float("nan")  # Test nan metric value
         error_container = []
         with patch(
@@ -1053,12 +1067,19 @@ def test_get_tracked_metric_value(self) -> None:
             dirpath="checkpoint",
             best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),
         )
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
+        error_container = []
+        with patch(
+            "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
+            side_effect=error_container.append,
         ):
             val_loss = train_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)

+        self.assertIn(
+            "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint. "
+            "Will not be included in checkpoint path, nor tracked for optimality.",
+            error_container,
+        )
+
         ckpt_cb = BaseCheckpointSaver(
             dirpath="checkpoint",
         )
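
For reference, here is a minimal, self-contained sketch of the assertion pattern these tests move to: instead of expecting a RuntimeError, Logger.error is patched with side_effect=error_container.append so the logged message can be asserted on while execution continues. The sketch uses only the standard library; save_checkpoint_metric and its logger are hypothetical stand-ins, not torchtnt code.

# Minimal sketch of the logger-capture pattern, standard library only.
# save_checkpoint_metric is a hypothetical stand-in for the callback logic.
import logging
import unittest
from unittest.mock import patch

logger = logging.getLogger(__name__)


def save_checkpoint_metric(unit: object, metric_name: str) -> None:
    # Log an error and keep going instead of raising when the monitored
    # metric attribute is missing, mirroring the behavior under test.
    if not hasattr(unit, metric_name):
        logger.error(
            f"Unit does not have attribute {metric_name}, "
            "unable to retrieve metric to checkpoint."
        )


class LoggerCaptureTest(unittest.TestCase):
    def test_missing_metric_logs_error(self) -> None:
        error_container = []
        # Patching Logger.error on the class routes every logged error message
        # into the list, so the test can assert on its contents afterwards.
        with patch("logging.Logger.error", side_effect=error_container.append):
            save_checkpoint_metric(object(), "train_loss")

        self.assertIn(
            "Unit does not have attribute train_loss, "
            "unable to retrieve metric to checkpoint.",
            error_container,
        )


if __name__ == "__main__":
    unittest.main()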