Skip to content

Commit ba2fb54

Browse files
alanhdu authored and facebook-github-bot committed
Handle scientific notation in metric values (#953)
Summary:
Pull Request resolved: #953

Previously, this code would crash if the metric value used scientific notation (like `6.486097566010406e+18`), which can cause job crashes (like https://www.internalfb.com/mlhub/pipelines/runs/mast/f676386628-alandu-hpo_mast_3342b13186?job_attempt=7&version=0&tab=debug&env=PRODUCTION). This updates the regex to handle scientific notation and adds a test case.

Reviewed By: diego-urgell, JKSenthil

Differential Revision: D67541312

fbshipit-source-id: d4bc637a9a0f93323cf55877025844f9b1651428
1 parent 2e762a1 commit ba2fb54

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

tests/utils/test_checkpoint.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,15 @@ def test_from_str(self) -> None:
251251
metric_data=MetricData("mean_loss_squared", 0.0),
252252
),
253253
),
254+
(
255+
"foo/bar/epoch_1_train_step_2_eval_step_3_eval_loss=6.486097566010406e+18",
256+
CheckpointPath(
257+
"foo/bar/",
258+
epoch=1,
259+
step={Phase.TRAIN: 2, Phase.EVALUATE: 3},
260+
metric_data=MetricData("eval_loss", 6.486097566010406e18),
261+
),
262+
),
254263
]
255264
for path, expected_ckpt in valid_paths:
256265
parsed_ckpt = CheckpointPath.from_str(path)

torchtnt/utils/checkpoint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,14 @@ class CheckpointPath:
8080
- phase-aware and metric-aware- <dirpath>/epoch_<epoch>_train_step_<train_step>_eval_step_<eval_step>_<metric_name>=<metric_value>
8181
"""
8282

83+
_FLOAT_REGEX: str = r"-?\d+\.?\d*(?:e[\+\-]\d+)?"
84+
8385
PHASE_NAIVE_REGEX: Pattern = re.compile(
84-
r"^(.+)epoch_(\d+)_step_(\d+)(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
86+
rf"^(.+)epoch_(\d+)_step_(\d+)(?:_(\w+)=({_FLOAT_REGEX}))?\/?$"
8587
)
8688

8789
PHASE_AWARE_REGEX: Pattern = re.compile(
88-
r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
90+
rf"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=({_FLOAT_REGEX}))?\/?$"
8991
)
9092

9193
def __init__(

0 commit comments

Comments
 (0)