Skip to content

Commit 1064d10

Browse files
diego-urgell authored and facebook-github-bot committed
Enable creation of predict checkpoint paths (#907)
Summary: Pull Request resolved: #907
Reviewed By: anshulverma, JKSenthil
Differential Revision: D63013010
fbshipit-source-id: f7e872ebd4b65d74f312a0dd25d220c4931af658
1 parent a577dd4 commit 1064d10

File tree

2 files changed

+86
-9
lines changed

2 files changed

+86
-9
lines changed

tests/utils/test_checkpoint.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,32 @@ def test_create_checkpoint_path(self) -> None:
6666
)
6767
self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_foo=1.0")
6868

69+
# evaluation only
70+
ckpt = CheckpointPath(
71+
"foo",
72+
epoch=0,
73+
step={Phase.EVALUATE: 1},
74+
)
75+
self.assertEqual(ckpt.path, "foo/epoch_0_eval_step_1")
76+
77+
# prediction only
78+
ckpt = CheckpointPath(
79+
"foo",
80+
epoch=0,
81+
step={Phase.PREDICT: 1},
82+
)
83+
self.assertEqual(ckpt.path, "foo/epoch_0_predict_step_1")
84+
85+
# all phases - not expected but should work
86+
ckpt = CheckpointPath(
87+
"foo",
88+
epoch=0,
89+
step={Phase.TRAIN: 1, Phase.EVALUATE: 1, Phase.PREDICT: 1},
90+
)
91+
self.assertEqual(
92+
ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_predict_step_1"
93+
)
94+
6995
# nan metric value
7096
with self.assertRaisesRegex(
7197
ValueError,
@@ -90,6 +116,8 @@ def test_from_str(self) -> None:
90116
"foo/epoch_2.6_step_23",
91117
"foo/epoch_3_pred_step_3",
92118
"foo/epoch_3__step_3",
119+
"foo/epoch_2_predict_step_2_eval_step_1",
120+
"foo/epoch_2_predict_step_3.2",
93121
]
94122
for path in malformed_paths:
95123
with self.assertRaisesRegex(
@@ -110,6 +138,15 @@ def test_from_str(self) -> None:
110138
"foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0)
111139
),
112140
),
141+
(
142+
"foo/epoch_14_step_3_train_loss=15.0",
143+
CheckpointPath(
144+
"foo",
145+
epoch=14,
146+
step={Phase.NONE: 3},
147+
metric_data=MetricData("train_loss", 15.0),
148+
),
149+
),
113150
(
114151
"foo/epoch_14_step_3_loss=-27.35",
115152
CheckpointPath(
@@ -122,6 +159,23 @@ def test_from_str(self) -> None:
122159
"/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
123160
),
124161
),
162+
(
163+
"foo/epoch_2_eval_step_23",
164+
CheckpointPath("foo", epoch=2, step={Phase.EVALUATE: 23}),
165+
),
166+
(
167+
"foo/epoch_14_predict_step_5",
168+
CheckpointPath("foo", epoch=14, step={Phase.PREDICT: 5}),
169+
),
170+
(
171+
"foo/epoch_14_train_step_3_eval_loss=0.1",
172+
CheckpointPath(
173+
"foo",
174+
epoch=14,
175+
step={Phase.TRAIN: 3},
176+
metric_data=MetricData("eval_loss", 0.1),
177+
),
178+
),
125179
(
126180
"foo/bar/epoch_23_step_31_mean_loss_squared=0.0",
127181
CheckpointPath(
@@ -266,6 +320,12 @@ def test_compare_by_recency(self) -> None:
266320
self.assertTrue(eval_only < multiphase_2)
267321
self.assertTrue(multiphase_2 < multiphase_3)
268322

323+
predict_1 = CheckpointPath("foo", epoch=3, step={Phase.PREDICT: 10})
324+
predict_2 = CheckpointPath("foo", epoch=4, step={Phase.PREDICT: 10})
325+
predict_3 = CheckpointPath("foo", epoch=4, step={Phase.PREDICT: 20})
326+
self.assertTrue(predict_1 < predict_2)
327+
self.assertTrue(predict_2 < predict_3)
328+
269329
def test_compare_by_optimality(self) -> None:
270330
# not both metric aware
271331
ckpt1 = CheckpointPath("foo", epoch=0, step=1)

torchtnt/utils/checkpoint.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class Phase(Enum):
5353
NONE = 0 # Only used for backwards compatibility
5454
TRAIN = 1
5555
EVALUATE = 2
56+
PREDICT = 3
5657

5758

5859
@total_ordering
@@ -81,7 +82,7 @@ class CheckpointPath:
8182
)
8283

8384
PHASE_AWARE_REGEX: Pattern = re.compile(
84-
r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
85+
r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
8586
)
8687

8788
def __init__(
@@ -142,8 +143,9 @@ def _populate_from_str(self, checkpoint_path: str) -> None:
142143
Raises:
143144
ValueError: If the path is malformed (either non-parsable, or contains wrong data types)
144145
"""
145-
is_phase_aware = (
146-
"train_step" in checkpoint_path or "eval_step" in checkpoint_path
146+
is_phase_aware = any(
147+
phase in checkpoint_path
148+
for phase in ["train_step", "eval_step", "predict_step"]
147149
)
148150
regex = self.PHASE_AWARE_REGEX if is_phase_aware else self.PHASE_NAIVE_REGEX
149151
path_match = regex.match(checkpoint_path)
@@ -155,13 +157,22 @@ def _populate_from_str(self, checkpoint_path: str) -> None:
155157
try:
156158
step_mapping: Dict[Phase, int] = {}
157159
if is_phase_aware:
158-
dirpath, epoch, train_steps, eval_steps, metric_name, metric_value = (
159-
path_match.groups()
160-
)
160+
(
161+
dirpath,
162+
epoch,
163+
train_steps,
164+
eval_steps,
165+
predict_steps,
166+
metric_name,
167+
metric_value,
168+
) = path_match.groups()
169+
161170
if train_steps is not None:
162171
step_mapping[Phase.TRAIN] = int(train_steps)
163172
if eval_steps is not None:
164173
step_mapping[Phase.EVALUATE] = int(eval_steps)
174+
if predict_steps is not None:
175+
step_mapping[Phase.PREDICT] = int(predict_steps)
165176

166177
else:
167178
dirpath, epoch, naive_steps, metric_name, metric_value = (
@@ -200,6 +211,8 @@ def path(self) -> str:
200211
name += f"_train_step_{self.step[Phase.TRAIN]}"
201212
if Phase.EVALUATE in self.step:
202213
name += f"_eval_step_{self.step[Phase.EVALUATE]}"
214+
if Phase.PREDICT in self.step:
215+
name += f"_predict_step_{self.step[Phase.PREDICT]}"
203216

204217
if self.metric_data:
205218
name += f"_{self.metric_data.name}={self.metric_data.value}"
@@ -240,9 +253,13 @@ def newer_than(self, other: "CheckpointPath") -> bool:
240253
# Otherwise, compare first by predict, then eval, and then train steps
241254
return self._get_phase_steps() > other._get_phase_steps()
242255

243-
def _get_phase_steps(self) -> Tuple[int, int]:
244-
"""Tuple with the phase steps ordered by phase priority in comparison (first eval, then train)."""
245-
return self.step.get(Phase.EVALUATE, 0), self.step.get(Phase.TRAIN, 0)
256+
def _get_phase_steps(self) -> Tuple[int, ...]:
257+
"""Tuple with the phase steps ordered by phase priority in comparison (predict, eval, train)."""
258+
return (
259+
self.step.get(Phase.PREDICT, 0),
260+
self.step.get(Phase.EVALUATE, 0),
261+
self.step.get(Phase.TRAIN, 0),
262+
)
246263

247264
def more_optimal_than(
248265
self, other: "CheckpointPath", mode: Literal["min", "max"]

0 commit comments

Comments
 (0)