
Commit 84d1be9

diego-urgell authored and facebook-github-bot committed
Generate predict/evaluate checkpoints in BaseCheckpointer (#914)
Summary: Pull Request resolved: #914
Reviewed By: JKSenthil
Differential Revision: D63013008
fbshipit-source-id: 9721a7fd194b1b380027e5acfe246aa73f69d73b
1 parent f9838b8 commit 84d1be9

File tree

4 files changed: 196 additions, 23 deletions

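Taken together, the change lets the same frequency-style knobs drive checkpointing from the evaluate and predict entrypoints. Below is a minimal sketch of the new usage; PrintingCheckpointer, EchoPredictUnit, and the /tmp/ckpts directory are hypothetical stand-ins, not part of this commit:

import torch
from torch.utils.data import DataLoader

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.predict import predict
from torchtnt.framework.state import State
from torchtnt.framework.unit import AppStateMixin, PredictUnit


class PrintingCheckpointer(BaseCheckpointer):
    # Hypothetical minimal subclass: a real checkpointer would serialize
    # unit.app_state() under checkpoint_id instead of just printing it.
    def _checkpoint_impl(
        self, state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
    ) -> bool:
        print(f"would save checkpoint {checkpoint_id} (hook={hook})")
        return True

    @staticmethod
    def restore(path: str, unit: AppStateMixin, **kwargs) -> None:
        pass  # restore logic omitted in this sketch


class EchoPredictUnit(PredictUnit[torch.Tensor]):
    def predict_step(self, state: State, data: torch.Tensor) -> torch.Tensor:
        return data


dataloader = DataLoader(torch.randn(8, 2), batch_size=2)
checkpointer = PrintingCheckpointer("/tmp/ckpts", save_every_n_predict_steps=2)

# New with this commit: the predict (and evaluate) entrypoints generate checkpoints too.
predict(EchoPredictUnit(), dataloader, callbacks=[checkpointer])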

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 84 additions & 0 deletions
@@ -25,6 +25,7 @@
     Batch,
     DummyAutoUnit,
     DummyFitUnit,
+    DummyPredictUnit,
     DummyTrainUnit,
     generate_random_dataloader,
     get_dummy_fit_state,
@@ -35,7 +36,9 @@
 )
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
 from torchtnt.framework.callbacks.lambda_callback import Lambda
+from torchtnt.framework.evaluate import evaluate
 from torchtnt.framework.fit import fit
+from torchtnt.framework.predict import predict
 from torchtnt.framework.state import ActivePhase, State

 from torchtnt.framework.train import train
@@ -57,7 +60,9 @@ def __init__(
         *,
         save_every_n_train_steps: Optional[int] = None,
         save_every_n_epochs: Optional[int] = None,
+        save_every_n_eval_steps: Optional[int] = None,
         save_every_n_eval_epochs: Optional[int] = None,
+        save_every_n_predict_steps: Optional[int] = None,
         keep_last_n_checkpoints: Optional[int] = None,
         best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
@@ -66,7 +71,9 @@ def __init__(
             dirpath,
             save_every_n_train_steps=save_every_n_train_steps,
             save_every_n_epochs=save_every_n_epochs,
+            save_every_n_eval_steps=save_every_n_eval_steps,
             save_every_n_eval_epochs=save_every_n_eval_epochs,
+            save_every_n_predict_steps=save_every_n_predict_steps,
             keep_last_n_checkpoints=keep_last_n_checkpoints,
             best_checkpoint_config=best_checkpoint_config,
             process_group=process_group,
@@ -243,6 +250,83 @@ def test_save_fit_entrypoint(self) -> None:
                 checkpointer._latest_checkpoint_path,
             )

+    @patch.object(BaseCheckpointSaver, "_checkpoint_impl")
+    def test_save_eval_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None:
+        my_unit = DummyFitUnit(input_dim=2)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            checkpointer = BaseCheckpointSaver(
+                temp_dir,
+                save_every_n_eval_steps=2,
+                best_checkpoint_config=BestCheckpointConfig(
+                    monitored_metric="val_loss", mode="min"
+                ),
+                keep_last_n_checkpoints=1,
+            )
+
+            ckpt_container: List[str] = []
+
+            def _checkpoint_impl_side_effect(
+                state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
+            ) -> bool:
+                ckpt_container.append(checkpoint_id)
+                return True
+
+            mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect
+
+            eval_dataloader = generate_random_dataloader(10, 2, 1)
+
+            warning_container: List[str] = []
+            with patch(
+                "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning",
+                side_effect=warning_container.append,
+            ):
+                evaluate(my_unit, eval_dataloader, callbacks=[checkpointer])
+
+            # Verify that checkpoint optimality tracking was disabled
+            self.assertIn(
+                "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints.",
+                warning_container,
+            )
+            self.assertIn(
+                "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints.",
+                warning_container,
+            )
+
+            # Make sure that the correct checkpoints were saved, without tracked metrics
+            expected_ckpts = [
+                f"{temp_dir}/epoch_0_eval_step_{i*2}" for i in range(1, 6)
+            ]
+            self.assertEqual(ckpt_container, expected_ckpts)
+
+    @patch.object(BaseCheckpointSaver, "_checkpoint_impl")
+    def test_save_predict_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None:
+        my_unit = DummyPredictUnit(input_dim=2)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            checkpointer = BaseCheckpointSaver(
+                temp_dir,
+                save_every_n_predict_steps=1,
+            )
+
+            ckpt_container: List[str] = []
+
+            def _checkpoint_impl_side_effect(
+                state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
+            ) -> bool:
+                ckpt_container.append(checkpoint_id)
+                return True
+
+            mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect
+
+            predict_dataloader = generate_random_dataloader(10, 2, 1)
+
+            predict(my_unit, predict_dataloader, callbacks=[checkpointer])
+
+            # Make sure that the correct checkpoints were saved
+            expected_ckpts = [
+                f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11)
+            ]
+            self.assertEqual(ckpt_container, expected_ckpts)
+
     @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
     def test_restore_from_latest(self, mock_stdout: MagicMock) -> None:
         my_unit = DummyTrainUnit(input_dim=2)
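The paths asserted above follow the checkpoint naming scheme, which now reflects the active phase instead of always reporting train progress. A rough sketch of how such a path is composed (the helper below is illustrative only, not the library's internal implementation):

def checkpoint_path(dirpath: str, epoch: int, phase: str, step: int) -> str:
    # Mirrors the names asserted in the tests above, e.g.
    # "<dirpath>/epoch_0_eval_step_2" or "<dirpath>/epoch_0_predict_step_1".
    return f"{dirpath}/epoch_{epoch}_{phase}_step_{step}"


assert checkpoint_path("/tmp/ckpts", 0, "eval", 2) == "/tmp/ckpts/epoch_0_eval_step_2"
assert checkpoint_path("/tmp/ckpts", 0, "predict", 1) == "/tmp/ckpts/epoch_0_predict_step_1"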

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 101 additions & 22 deletions
@@ -15,10 +15,19 @@
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torchtnt.framework.callback import Callback
-from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping
+from torchtnt.framework.callbacks._checkpoint_utils import (
+    _get_epoch,
+    _get_step_phase_mapping,
+)
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
-from torchtnt.framework.state import EntryPoint, State
-from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
+from torchtnt.framework.state import ActivePhase, EntryPoint, State
+from torchtnt.framework.unit import (
+    AppStateMixin,
+    TEvalUnit,
+    TPredictUnit,
+    TTrainData,
+    TTrainUnit,
+)
 from torchtnt.utils.checkpoint import (
     BestCheckpointConfig,
     CheckpointManager,
@@ -51,8 +60,11 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
         save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
         save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
         save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
-        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
-        best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
+        save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints. Use this to save checkpoints during the evaluate entrypoint.
+        save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints. Use this to save checkpoints during the predict entrypoint.
+        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted
+            to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. Only supported for train or fit entrypoints.
+        best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. This param is ignored if not in train or fit entrypoints.
         process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created.

     Note:
@@ -78,6 +90,8 @@ def __init__(
         save_every_n_train_steps: Optional[int] = None,
         save_every_n_epochs: Optional[int] = None,
         save_every_n_eval_epochs: Optional[int] = None,
+        save_every_n_eval_steps: Optional[int] = None,
+        save_every_n_predict_steps: Optional[int] = None,
         keep_last_n_checkpoints: Optional[int] = None,
         best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
@@ -90,12 +104,23 @@ def __init__(
             raise ValueError(
                 f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}"
             )
+        if save_every_n_eval_steps is not None and save_every_n_eval_steps <= 0:
+            raise ValueError(
+                f"Invalid value passed for save_every_n_eval_steps. Expected to receive either None or positive number, but received {save_every_n_eval_steps}"
+            )
+        if save_every_n_eval_epochs is not None and save_every_n_eval_epochs <= 0:
+            raise ValueError(
+                f"Invalid value passed for save_every_n_eval_epochs. Expected to receive either None or positive number, but received {save_every_n_eval_epochs}"
+            )
+        if save_every_n_predict_steps is not None and save_every_n_predict_steps <= 0:
+            raise ValueError(
+                f"Invalid value passed for save_every_n_predict_steps. Expected to receive either None or positive number, but received {save_every_n_predict_steps}"
+            )
         if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0:
             raise ValueError(
                 f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}"
             )

-        self._best_checkpoint_config = best_checkpoint_config
         if best_checkpoint_config and best_checkpoint_config.mode not in {"min", "max"}:
             raise ValueError(
                 f"Invalid value passed for best_checkpoint_config.mode. Expected to receive 'min' or 'max', but received {best_checkpoint_config.mode}"
@@ -104,7 +129,10 @@ def __init__(
         self._save_every_n_train_steps = save_every_n_train_steps
         self._save_every_n_epochs = save_every_n_epochs
         self._save_every_n_eval_epochs = save_every_n_eval_epochs
+        self._save_every_n_eval_steps = save_every_n_eval_steps
+        self._save_every_n_predict_steps = save_every_n_predict_steps
         self._keep_last_n_checkpoints = keep_last_n_checkpoints
+        self._best_checkpoint_config = best_checkpoint_config

         self._process_group: Optional[dist.ProcessGroup] = None
         self._setup_gloo_pg(process_group)
@@ -147,7 +175,7 @@ def dirpath(self) -> str:
         return self._checkpoint_manager.dirpath

     def _generate_checkpoint_and_upkeep(
-        self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
+        self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str
     ) -> bool:
         """
         Implementation for saving checkpoint while taking care of checkpoint
@@ -162,11 +190,16 @@ def _generate_checkpoint_and_upkeep(
             True if checkpoint was successfully saved. False otherwise.
         """
         # 1) generate checkpoint name
-        epoch = cast(TTrainUnit, unit).train_progress.num_epochs_completed
+        epoch = _get_epoch(state, unit)
         step_mapping = _get_step_phase_mapping(state, unit)

+        # 1.1) append metric data only for train checkpoints, if best_checkpoint_config is defined
         metric_data: Optional[MetricData] = None
-        if metric_value := self._get_tracked_metric_value(unit):
+        if (
+            self._best_checkpoint_config
+            and state.active_phase == ActivePhase.TRAIN
+            and (metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit)))
+        ):
             metric_data = MetricData(
                 name=none_throws(self._best_checkpoint_config).monitored_metric,
                 value=metric_value,
@@ -179,7 +212,8 @@ def _generate_checkpoint_and_upkeep(
             process_group=self._process_group,
         )

-        # 2) Determine if we should save checkpoint
+        # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
+        # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
         if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
             return False

@@ -222,9 +256,7 @@ def _generate_checkpoint_and_upkeep(

         return True

-    def _get_tracked_metric_value(
-        self, unit: Union[TTrainUnit, TEvalUnit]
-    ) -> Optional[float]:
+    def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
         """
         If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float.

@@ -271,33 +303,80 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None:

     def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         num_steps_completed = unit.train_progress.num_steps_completed
-        save_every_n_train_steps = self._save_every_n_train_steps
         if (
-            save_every_n_train_steps is None
-            or num_steps_completed % save_every_n_train_steps != 0
+            not self._save_every_n_train_steps
+            or num_steps_completed % self._save_every_n_train_steps != 0
         ):
             return

         self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_step_end")

     def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
         epoch = unit.train_progress.num_epochs_completed
-        save_every_n_epochs = self._save_every_n_epochs
-        if save_every_n_epochs is None or epoch % save_every_n_epochs != 0:
+        if not self._save_every_n_epochs or epoch % self._save_every_n_epochs != 0:
             return

         self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_epoch_end")

+    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
+        self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end")
+
+    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
+        if state.entry_point == EntryPoint.EVALUATE:
+            self._disable_ckpt_optimality_tracking()
+
+    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
+        num_steps_completed = unit.eval_progress.num_steps_completed
+        if (
+            not self._save_every_n_eval_steps
+            or num_steps_completed % self._save_every_n_eval_steps != 0
+        ):
+            return
+
+        self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_step_end")
+
     def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
         epoch = unit.eval_progress.num_epochs_completed
-        save_every_n_eval_epochs = self._save_every_n_eval_epochs
-        if save_every_n_eval_epochs is None or epoch % save_every_n_eval_epochs != 0:
+        if (
+            not self._save_every_n_eval_epochs
+            or epoch % self._save_every_n_eval_epochs != 0
+        ):
             return

         self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_epoch_end")

-    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
-        self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end")
+    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
+        self._disable_ckpt_optimality_tracking()
+
+    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
+        num_steps_completed = unit.predict_progress.num_steps_completed
+        if (
+            not self._save_every_n_predict_steps
+            or num_steps_completed % self._save_every_n_predict_steps != 0
+        ):
+            return
+
+        self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end")
+
+    def _disable_ckpt_optimality_tracking(self) -> None:
+        """
+        Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints
+        will not be used. This is useful for eval and predict entrypoints, since checkpoints do not include
+        model parameters.
+        """
+        if self._best_checkpoint_config:
+            logger.warning(
+                "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints."
+            )
+            self._best_checkpoint_config = None
+            self._checkpoint_manager._best_checkpoint_config = None
+
+        if self._keep_last_n_checkpoints:
+            logger.warning(
+                "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints."
+            )
+            self._keep_last_n_checkpoints = None
+            self._checkpoint_manager._keep_last_n_checkpoints = None

     @abc.abstractmethod
     def _checkpoint_impl(
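Note that on_eval_start only disables optimality tracking for the standalone evaluate entrypoint; evaluation that runs inside fit (entry_point == EntryPoint.FIT) keeps best_checkpoint_config and keep_last_n_checkpoints, while on_predict_start always disables them. A condensed, illustrative restatement of that control flow (maybe_disable_tracking is a hypothetical helper, not part of the commit):

import logging

from torchtnt.framework.state import EntryPoint

logger = logging.getLogger(__name__)


def maybe_disable_tracking(checkpointer, entry_point: EntryPoint) -> None:
    # Standalone evaluate and predict drop best_checkpoint_config and
    # keep_last_n_checkpoints; eval that runs inside fit keeps them.
    if entry_point not in (EntryPoint.EVALUATE, EntryPoint.PREDICT):
        return
    if checkpointer._best_checkpoint_config or checkpointer._keep_last_n_checkpoints:
        logger.warning("Disabling checkpoint optimality tracking for %s", entry_point)
    checkpointer._best_checkpoint_config = None
    checkpointer._keep_last_n_checkpoints = None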

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 1 addition & 1 deletion
@@ -318,7 +318,7 @@ def restore_with_id(
         )

     def _generate_checkpoint_and_upkeep(
-        self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
+        self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str
     ) -> bool:
         # if we are still checkpointing, this might cause a collective hang, since several
         # operations in the base class use the process group. So wait here instead.

torchtnt/framework/unit.py

Lines changed: 10 additions & 0 deletions
@@ -586,6 +586,16 @@ def on_predict_epoch_end(self, state: State) -> None:
         """
         pass

+    def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None:
+        """Hook called after successfully saving a checkpoint.
+
+        Args:
+            state: a :class:`~torchtnt.framework.state.State` object containing metadata about the training run.
+            checkpoint_id: the ID of the checkpoint that was saved. Depending on the storage type, this may be
+                a path, a URL or a unique identifier.
+        """
+        pass
+
     def on_predict_end(self, state: State) -> None:
         """Hook called after prediction ends.
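The new on_checkpoint_save hook gives a predict unit a place to react after a checkpoint has been written, for example to remember the last checkpoint ID. A small hypothetical unit overriding it (TrackingPredictUnit and its bookkeeping are illustrative, not part of the commit):

from typing import Optional

import torch

from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit


class TrackingPredictUnit(PredictUnit[torch.Tensor]):
    def __init__(self) -> None:
        super().__init__()
        self.last_checkpoint_id: Optional[str] = None

    def predict_step(self, state: State, data: torch.Tensor) -> torch.Tensor:
        return data

    def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None:
        # Called after a checkpoint is saved; checkpoint_id may be a path,
        # a URL, or another storage-specific identifier.
        self.last_checkpoint_id = checkpoint_id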
