15
15
import torch .distributed as dist
16
16
from pyre_extensions import none_throws
17
17
from torchtnt .framework .callback import Callback
18
- from torchtnt .framework .callbacks ._checkpoint_utils import _get_step_phase_mapping
18
+ from torchtnt .framework .callbacks ._checkpoint_utils import (
19
+ _get_epoch ,
20
+ _get_step_phase_mapping ,
21
+ )
19
22
from torchtnt .framework .callbacks .checkpointer_types import RestoreOptions
20
- from torchtnt .framework .state import EntryPoint , State
21
- from torchtnt .framework .unit import AppStateMixin , TEvalUnit , TTrainData , TTrainUnit
23
+ from torchtnt .framework .state import ActivePhase , EntryPoint , State
24
+ from torchtnt .framework .unit import (
25
+ AppStateMixin ,
26
+ TEvalUnit ,
27
+ TPredictUnit ,
28
+ TTrainData ,
29
+ TTrainUnit ,
30
+ )
22
31
from torchtnt .utils .checkpoint import (
23
32
BestCheckpointConfig ,
24
33
CheckpointManager ,
@@ -51,8 +60,11 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
51
60
save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
52
61
save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
53
62
save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
54
- keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
55
- best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
63
+ save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints. Use this to save checkpoints during evaluation.
64
+ save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints. Use this to save checkpoints when using the predict entrypoint.
65
+ keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted
66
+ to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. Only supported for train or fit entrypoints.
67
+ best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. This param is ignored if not in train or fit entrypoints.
56
68
process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created.
57
69
58
70
Note:
@@ -78,6 +90,8 @@ def __init__(
78
90
save_every_n_train_steps : Optional [int ] = None ,
79
91
save_every_n_epochs : Optional [int ] = None ,
80
92
save_every_n_eval_epochs : Optional [int ] = None ,
93
+ save_every_n_eval_steps : Optional [int ] = None ,
94
+ save_every_n_predict_steps : Optional [int ] = None ,
81
95
keep_last_n_checkpoints : Optional [int ] = None ,
82
96
best_checkpoint_config : Optional [BestCheckpointConfig ] = None ,
83
97
process_group : Optional [dist .ProcessGroup ] = None ,
@@ -90,12 +104,23 @@ def __init__(
90
104
raise ValueError (
91
105
f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received { save_every_n_epochs } "
92
106
)
107
+ if save_every_n_eval_steps is not None and save_every_n_eval_steps <= 0 :
108
+ raise ValueError (
109
+ f"Invalid value passed for save_every_n_eval_steps. Expected to receive either None or positive number, but received { save_every_n_eval_steps } "
110
+ )
111
+ if save_every_n_eval_epochs is not None and save_every_n_eval_epochs <= 0 :
112
+ raise ValueError (
113
+ f"Invalid value passed for save_every_n_eval_epochs. Expected to receive either None or positive number, but received { save_every_n_eval_epochs } "
114
+ )
115
+ if save_every_n_predict_steps is not None and save_every_n_predict_steps <= 0 :
116
+ raise ValueError (
117
+ f"Invalid value passed for save_every_n_predict_steps. Expected to receive either None or positive number, but received { save_every_n_predict_steps } "
118
+ )
93
119
if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0 :
94
120
raise ValueError (
95
121
f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received { keep_last_n_checkpoints } "
96
122
)
97
123
98
- self ._best_checkpoint_config = best_checkpoint_config
99
124
if best_checkpoint_config and best_checkpoint_config .mode not in {"min" , "max" }:
100
125
raise ValueError (
101
126
f"Invalid value passed for best_checkpoint_config.mode. Expected to receive 'min' or 'max', but received { best_checkpoint_config .mode } "
@@ -104,7 +129,10 @@ def __init__(
104
129
self ._save_every_n_train_steps = save_every_n_train_steps
105
130
self ._save_every_n_epochs = save_every_n_epochs
106
131
self ._save_every_n_eval_epochs = save_every_n_eval_epochs
132
+ self ._save_every_n_eval_steps = save_every_n_eval_steps
133
+ self ._save_every_n_predict_steps = save_every_n_predict_steps
107
134
self ._keep_last_n_checkpoints = keep_last_n_checkpoints
135
+ self ._best_checkpoint_config = best_checkpoint_config
108
136
109
137
self ._process_group : Optional [dist .ProcessGroup ] = None
110
138
self ._setup_gloo_pg (process_group )
@@ -147,7 +175,7 @@ def dirpath(self) -> str:
147
175
return self ._checkpoint_manager .dirpath
148
176
149
177
def _generate_checkpoint_and_upkeep (
150
- self , state : State , unit : Union [TTrainUnit , TEvalUnit ], hook : str
178
+ self , state : State , unit : Union [TTrainUnit , TEvalUnit , TPredictUnit ], hook : str
151
179
) -> bool :
152
180
"""
153
181
Implementation for saving checkpoint while taking care of checkpoint
@@ -162,11 +190,16 @@ def _generate_checkpoint_and_upkeep(
162
190
True if checkpoint was successfully saved. False otherwise.
163
191
"""
164
192
# 1) generate checkpoint name
165
- epoch = cast ( TTrainUnit , unit ). train_progress . num_epochs_completed
193
+ epoch = _get_epoch ( state , unit )
166
194
step_mapping = _get_step_phase_mapping (state , unit )
167
195
196
+ # 1.1) append metric data only for train checkpoints, if best_checkpoint_config is defined
168
197
metric_data : Optional [MetricData ] = None
169
- if metric_value := self ._get_tracked_metric_value (unit ):
198
+ if (
199
+ self ._best_checkpoint_config
200
+ and state .active_phase == ActivePhase .TRAIN
201
+ and (metric_value := self ._get_tracked_metric_value (cast (TTrainUnit , unit )))
202
+ ):
170
203
metric_data = MetricData (
171
204
name = none_throws (self ._best_checkpoint_config ).monitored_metric ,
172
205
value = metric_value ,
@@ -179,7 +212,8 @@ def _generate_checkpoint_and_upkeep(
179
212
process_group = self ._process_group ,
180
213
)
181
214
182
- # 2) Determine if we should save checkpoint
215
+ # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
216
+ # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
183
217
if not self ._checkpoint_manager .should_save_checkpoint (checkpoint_path ):
184
218
return False
185
219
@@ -222,9 +256,7 @@ def _generate_checkpoint_and_upkeep(
222
256
223
257
return True
224
258
225
- def _get_tracked_metric_value (
226
- self , unit : Union [TTrainUnit , TEvalUnit ]
227
- ) -> Optional [float ]:
259
+ def _get_tracked_metric_value (self , unit : TTrainUnit ) -> Optional [float ]:
228
260
"""
229
261
If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float.
230
262
@@ -271,33 +303,80 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None:
271
303
272
304
def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
    """
    Save a checkpoint every ``save_every_n_train_steps`` completed train steps.

    Nothing is saved when ``save_every_n_train_steps`` is unset, or when the
    number of completed train steps is not a multiple of it.

    Args:
        state: current state, forwarded unchanged to the checkpoint routine.
        unit: train unit whose ``train_progress.num_steps_completed`` is inspected.
    """
    num_steps_completed = unit.train_progress.num_steps_completed
    if (
        not self._save_every_n_train_steps
        or num_steps_completed % self._save_every_n_train_steps != 0
    ):
        return

    self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_step_end")
282
313
283
314
def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
    """
    Save a checkpoint every ``save_every_n_epochs`` completed train epochs.

    Nothing is saved when ``save_every_n_epochs`` is unset, or when the number
    of completed train epochs is not a multiple of it.

    Args:
        state: current state, forwarded unchanged to the checkpoint routine.
        unit: train unit whose ``train_progress.num_epochs_completed`` is inspected.
    """
    epoch = unit.train_progress.num_epochs_completed
    if not self._save_every_n_epochs or epoch % self._save_every_n_epochs != 0:
        return

    self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_epoch_end")
290
320
321
def on_train_end(self, state: State, unit: TTrainUnit) -> None:
    """
    Request a checkpoint when training ends.

    No frequency setting is consulted here; the checkpoint routine is always
    invoked with the ``on_train_end`` hook name.

    Args:
        state: current state, forwarded unchanged to the checkpoint routine.
        unit: train unit, forwarded unchanged to the checkpoint routine.
    """
    self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end")
323
+
324
def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
    """
    Turn off checkpoint optimality tracking for the evaluate entrypoint.

    The tracking is only disabled when ``state.entry_point`` is
    ``EntryPoint.EVALUATE``; for any other entry point this hook does nothing.

    Args:
        state: current state; only ``entry_point`` is read.
        unit: eval unit (unused by this hook).
    """
    if state.entry_point == EntryPoint.EVALUATE:
        self._disable_ckpt_optimality_tracking()
327
+
328
def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
    """
    Save a checkpoint every ``save_every_n_eval_steps`` completed eval steps.

    Nothing is saved when ``save_every_n_eval_steps`` is unset, or when the
    number of completed eval steps is not a multiple of it.

    Args:
        state: current state, forwarded unchanged to the checkpoint routine.
        unit: eval unit whose ``eval_progress.num_steps_completed`` is inspected.
    """
    num_steps_completed = unit.eval_progress.num_steps_completed
    if (
        not self._save_every_n_eval_steps
        or num_steps_completed % self._save_every_n_eval_steps != 0
    ):
        return

    self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_step_end")
337
+
291
338
def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
    """
    Save a checkpoint every ``save_every_n_eval_epochs`` completed eval epochs.

    Nothing is saved when ``save_every_n_eval_epochs`` is unset, or when the
    number of completed eval epochs is not a multiple of it.

    Args:
        state: current state, forwarded unchanged to the checkpoint routine.
        unit: eval unit whose ``eval_progress.num_epochs_completed`` is inspected.
    """
    epoch = unit.eval_progress.num_epochs_completed
    if (
        not self._save_every_n_eval_epochs
        or epoch % self._save_every_n_eval_epochs != 0
    ):
        return

    self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_epoch_end")
298
347
299
- def on_train_end (self , state : State , unit : TTrainUnit ) -> None :
300
- self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_train_end" )
348
def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
    """
    Turn off checkpoint optimality tracking when prediction starts.

    Unlike ``on_eval_start``, no entry-point check is made here: the tracking
    is disabled unconditionally.

    Args:
        state: current state (unused by this hook).
        unit: predict unit (unused by this hook).
    """
    self._disable_ckpt_optimality_tracking()
350
+
351
def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
    """
    Save a checkpoint every ``save_every_n_predict_steps`` completed predict steps.

    Nothing is saved when ``save_every_n_predict_steps`` is unset, or when the
    number of completed predict steps is not a multiple of it.

    Args:
        state: current state, forwarded unchanged to the checkpoint routine.
        unit: predict unit whose ``predict_progress.num_steps_completed`` is inspected.
    """
    num_steps_completed = unit.predict_progress.num_steps_completed
    if (
        not self._save_every_n_predict_steps
        or num_steps_completed % self._save_every_n_predict_steps != 0
    ):
        return

    self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end")
360
+
361
def _disable_ckpt_optimality_tracking(self) -> None:
    """
    Disables checkpoint optimality tracking. This means that best_checkpoint_config and
    keep_last_n_checkpoints will not be used. This is useful for eval and predict entrypoints,
    since checkpoints do not include model parameters.

    For each of the two settings that is currently set, a warning is logged and the
    setting is cleared both on this checkpointer and on its checkpoint manager.
    Settings that are already unset are left untouched, so repeated calls log nothing new.
    """
    if self._best_checkpoint_config:
        logger.warning(
            "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints."
        )
        self._best_checkpoint_config = None
        # Keep the manager in sync so it stops ranking checkpoints by metric.
        self._checkpoint_manager._best_checkpoint_config = None

    if self._keep_last_n_checkpoints:
        logger.warning(
            "Disabling keep_last_n_checkpoints, since it is not supported for eval or predict entrypoints."
        )
        self._keep_last_n_checkpoints = None
        # Keep the manager in sync so it stops pruning older checkpoints.
        self._checkpoint_manager._keep_last_n_checkpoints = None
301
380
302
381
@abc .abstractmethod
303
382
def _checkpoint_impl (
0 commit comments