@@ -53,6 +53,7 @@ class Phase(Enum):
53
53
NONE = 0 # Only used for backwards compatibility
54
54
TRAIN = 1
55
55
EVALUATE = 2
56
+ PREDICT = 3
56
57
57
58
58
59
@total_ordering
@@ -81,7 +82,7 @@ class CheckpointPath:
81
82
)
82
83
83
84
PHASE_AWARE_REGEX : Pattern = re .compile (
84
- r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
85
+ r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
85
86
)
86
87
87
88
def __init__ (
@@ -142,8 +143,9 @@ def _populate_from_str(self, checkpoint_path: str) -> None:
142
143
Raises:
143
144
ValueError: If the path is malformed (either non-parsable, or contains wrong data types)
144
145
"""
145
- is_phase_aware = (
146
- "train_step" in checkpoint_path or "eval_step" in checkpoint_path
146
+ is_phase_aware = any (
147
+ phase in checkpoint_path
148
+ for phase in ["train_step" , "eval_step" , "predict_step" ]
147
149
)
148
150
regex = self .PHASE_AWARE_REGEX if is_phase_aware else self .PHASE_NAIVE_REGEX
149
151
path_match = regex .match (checkpoint_path )
@@ -155,13 +157,22 @@ def _populate_from_str(self, checkpoint_path: str) -> None:
155
157
try :
156
158
step_mapping : Dict [Phase , int ] = {}
157
159
if is_phase_aware :
158
- dirpath , epoch , train_steps , eval_steps , metric_name , metric_value = (
159
- path_match .groups ()
160
- )
160
+ (
161
+ dirpath ,
162
+ epoch ,
163
+ train_steps ,
164
+ eval_steps ,
165
+ predict_steps ,
166
+ metric_name ,
167
+ metric_value ,
168
+ ) = path_match .groups ()
169
+
161
170
if train_steps is not None :
162
171
step_mapping [Phase .TRAIN ] = int (train_steps )
163
172
if eval_steps is not None :
164
173
step_mapping [Phase .EVALUATE ] = int (eval_steps )
174
+ if predict_steps is not None :
175
+ step_mapping [Phase .PREDICT ] = int (predict_steps )
165
176
166
177
else :
167
178
dirpath , epoch , naive_steps , metric_name , metric_value = (
@@ -200,6 +211,8 @@ def path(self) -> str:
200
211
name += f"_train_step_{ self .step [Phase .TRAIN ]} "
201
212
if Phase .EVALUATE in self .step :
202
213
name += f"_eval_step_{ self .step [Phase .EVALUATE ]} "
214
+ if Phase .PREDICT in self .step :
215
+ name += f"_predict_step_{ self .step [Phase .PREDICT ]} "
203
216
204
217
if self .metric_data :
205
218
name += f"_{ self .metric_data .name } ={ self .metric_data .value } "
@@ -240,9 +253,13 @@ def newer_than(self, other: "CheckpointPath") -> bool:
240
253
# Otherwise, compare first by predict, then eval, and then train steps
241
254
return self ._get_phase_steps () > other ._get_phase_steps ()
242
255
243
- def _get_phase_steps (self ) -> Tuple [int , int ]:
244
- """Tuple with the phase steps ordered by phase priority in comparison (first eval, then train)."""
245
- return self .step .get (Phase .EVALUATE , 0 ), self .step .get (Phase .TRAIN , 0 )
256
+ def _get_phase_steps (self ) -> Tuple [int , ...]:
257
+ """Tuple with the phase steps ordered by phase priority in comparison (predict, eval, train)."""
258
+ return (
259
+ self .step .get (Phase .PREDICT , 0 ),
260
+ self .step .get (Phase .EVALUATE , 0 ),
261
+ self .step .get (Phase .TRAIN , 0 ),
262
+ )
246
263
247
264
def more_optimal_than (
248
265
self , other : "CheckpointPath" , mode : Literal ["min" , "max" ]
0 commit comments