
Commit 764f36e

diego-urgell authored and facebook-github-bot committed
Prepare app_state for eval/predict checkpoints (#908)
Summary: Pull Request resolved: #908

Reviewed By: JKSenthil
Differential Revision: D63013009
fbshipit-source-id: f56673d7ff5114d1d7456f7c38511939f5e9bd56
1 parent 47fbf01 commit 764f36e

3 files changed: +66 −12 lines


tests/framework/callbacks/test_checkpoint_utils.py

Lines changed: 19 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
+    DummyEvalUnit,
     DummyMeanMetric,
     DummyTrainUnit,
     generate_dummy_stateful_dataloader,
@@ -64,6 +65,24 @@ def test_get_app_state(self) -> None:
             ],
         )
 
+        # Test evaluate intra-epoch checkpoint
+        my_unit = DummyEvalUnit(input_dim=2)
+        my_unit.mean_metric = DummyMeanMetric()  # pyre-ignore[16]
+        state = get_dummy_eval_state()
+        stateful_dl = generate_dummy_stateful_dataloader(1, 1, 1)
+        state._active_phase = ActivePhase.EVALUATE
+        none_throws(state.eval_state)._dataloader = stateful_dl
+
+        app_state = _prepare_app_state_for_checkpoint(state, my_unit, intra_epoch=True)
+        self.assertCountEqual(
+            app_state.keys(),
+            [
+                "eval_progress",
+                "eval_dataloader",
+                "mean_metric",
+            ],
+        )
+
     def test_get_step_phase_mapping(self) -> None:
         unit = DummyAutoUnit(module=nn.Linear(2, 2))
         unit.train_progress._num_steps_completed = 5
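
The new test covers the evaluate phase only; the predict path exercises the same code via the Phase.PREDICT → "predict_dataloader" mapping and the new "predict_progress" key (both visible in the next file). A minimal sketch of the analogous predict-phase assertion, assuming DummyPredictUnit and get_dummy_predict_state exist as eval-style test helpers (they are not part of this diff):

# Hypothetical predict-phase analogue of the eval test above. DummyPredictUnit
# and get_dummy_predict_state are assumed stand-ins mirroring the eval helpers.
my_unit = DummyPredictUnit(input_dim=2)
state = get_dummy_predict_state()
stateful_dl = generate_dummy_stateful_dataloader(1, 1, 1)
state._active_phase = ActivePhase.PREDICT
none_throws(state.predict_state)._dataloader = stateful_dl

app_state = _prepare_app_state_for_checkpoint(state, my_unit, intra_epoch=True)
# Expected keys mirror the eval case: the phase progress plus its dataloader.
self.assertCountEqual(
    app_state.keys(),
    ["predict_progress", "predict_dataloader"],
)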

torchtnt/framework/callbacks/_checkpoint_utils.py

Lines changed: 45 additions & 12 deletions
@@ -24,8 +24,10 @@
     Phase.PREDICT: "predict_dataloader",
 }
 _TRAIN_DL_STATE_KEY = "train_dataloader"
+
 _TRAIN_PROGRESS_STATE_KEY = "train_progress"
 _EVAL_PROGRESS_STATE_KEY = "eval_progress"
+_PREDICT_PROGRESS_STATE_KEY = "predict_progress"
 
 
 def _get_step_phase_mapping(
@@ -56,6 +58,30 @@ def _prepare_app_state(unit: AppStateMixin) -> Dict[str, Any]:
     return app_state
 
 
+def _remove_app_state_keys(
+    unit: AppStateMixin,
+    app_state: Dict[str, Any],
+    *,
+    remove_modules: bool = False,
+    remove_optimizers: bool = False,
+    remove_lr_schedulers: bool = False,
+) -> None:
+    if remove_modules:
+        # remove all module keys from app_state
+        for module_keys in unit.tracked_modules().keys():
+            app_state.pop(module_keys, None)
+
+    if remove_optimizers:
+        # remove all optimizer keys from app_state
+        for optim_keys in unit.tracked_optimizers().keys():
+            app_state.pop(optim_keys, None)
+
+    if remove_lr_schedulers:
+        # remove all lr scheduler keys from app_state
+        for lr_scheduler_keys in unit.tracked_lr_schedulers().keys():
+            app_state.pop(lr_scheduler_keys, None)
+
+
 def _prepare_app_state_for_checkpoint(
     state: State, unit: AppStateMixin, intra_epoch: bool
 ) -> Dict[str, Stateful]:
@@ -64,6 +90,16 @@ def _prepare_app_state_for_checkpoint(
     """
     app_state = _prepare_app_state(unit)
 
+    if state.entry_point in [EntryPoint.EVALUATE, EntryPoint.PREDICT]:
+        # Since model parameters are fixed, remove them from checkpoint.
+        _remove_app_state_keys(
+            unit,
+            app_state,
+            remove_modules=True,
+            remove_optimizers=True,
+            remove_lr_schedulers=True,
+        )
+
     # for intra-epoch checkpointing, include dataloader state of the current phase
     phase_dl = state.active_phase_state().dataloader
     if intra_epoch and isinstance(phase_dl, Stateful):
@@ -85,24 +121,21 @@ def _prepare_app_state_for_restore(
 
     restore_options = restore_options or RestoreOptions()
 
-    if not restore_options.restore_modules:
-        for module_keys in unit.tracked_modules().keys():
-            app_state.pop(module_keys, None)
-
     if not restore_options.restore_train_progress:
         app_state.pop(_TRAIN_PROGRESS_STATE_KEY, None)
 
     if not restore_options.restore_eval_progress:
         app_state.pop(_EVAL_PROGRESS_STATE_KEY, None)
 
-    if not restore_options.restore_optimizers:
-        # remove all optimizer keys from app_state
-        for optim_keys in unit.tracked_optimizers().keys():
-            app_state.pop(optim_keys, None)
+    if not restore_options.restore_predict_progress:
+        app_state.pop(_PREDICT_PROGRESS_STATE_KEY, None)
 
-    if not restore_options.restore_lr_schedulers:
-        # remove all lr scheduler keys from app_state
-        for lr_scheduler_keys in unit.tracked_lr_schedulers().keys():
-            app_state.pop(lr_scheduler_keys, None)
+    _remove_app_state_keys(
+        unit,
+        app_state,
+        remove_modules=not restore_options.restore_modules,
+        remove_optimizers=not restore_options.restore_optimizers,
+        remove_lr_schedulers=not restore_options.restore_lr_schedulers,
+    )
 
     return app_state
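
To see the new helper's effect in isolation, here is a self-contained sketch reproducing the key-filtering pattern on a plain dict; FakeUnit and its contents are illustrative stand-ins, not torchtnt APIs:

# Self-contained sketch of the filtering pattern _remove_app_state_keys
# implements. FakeUnit is an illustrative stand-in for an AppStateMixin unit.
from typing import Any, Dict


class FakeUnit:
    # Exposes the three tracked_* accessors the helper relies on.
    def tracked_modules(self) -> Dict[str, Any]:
        return {"module": object()}

    def tracked_optimizers(self) -> Dict[str, Any]:
        return {"optimizer": object()}

    def tracked_lr_schedulers(self) -> Dict[str, Any]:
        return {"lr_scheduler": object()}


def remove_app_state_keys(
    unit: FakeUnit,
    app_state: Dict[str, Any],
    *,
    remove_modules: bool = False,
    remove_optimizers: bool = False,
    remove_lr_schedulers: bool = False,
) -> None:
    # Pop each tracked key only when its flag is set; missing keys are a no-op.
    if remove_modules:
        for key in unit.tracked_modules():
            app_state.pop(key, None)
    if remove_optimizers:
        for key in unit.tracked_optimizers():
            app_state.pop(key, None)
    if remove_lr_schedulers:
        for key in unit.tracked_lr_schedulers():
            app_state.pop(key, None)


app_state = {"module": 1, "optimizer": 2, "lr_scheduler": 3, "eval_progress": 4}
remove_app_state_keys(
    FakeUnit(),
    app_state,
    remove_modules=True,
    remove_optimizers=True,
    remove_lr_schedulers=True,
)
print(sorted(app_state))  # ['eval_progress']

This is the deduplication the diff performs: the eval/predict checkpoint path and the restore path now share one helper instead of carrying separate copies of the pop loops.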

torchtnt/framework/callbacks/checkpointer_types.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@ class RestoreOptions:
        restore_modules: Whether to restore the module state dict.
        restore_train_progress: Whether to restore the training progress state.
        restore_eval_progress: Whether to restore the evaluation progress state.
+        restore_predict_progress: Whether to restore the prediction progress state.
        restore_optimizers: Whether to restore the optimizer states.
        restore_lr_schedulers: Whether to restore the lr scheduler states.
        strict: Whether to strictly restore app state and the module state dict.
@@ -47,6 +48,7 @@
    restore_modules: bool = True
    restore_train_progress: bool = True
    restore_eval_progress: bool = True
+    restore_predict_progress: bool = True
    restore_optimizers: bool = True
    restore_lr_schedulers: bool = True
    strict: bool = True
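
A brief usage sketch of the new flag, assuming only what the diffs above show (RestoreOptions lives in checkpointer_types, and _prepare_app_state_for_restore consumes it); the chosen flag values are illustrative:

from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions

# Keep module weights and train/eval progress, but skip prediction progress
# and optimizer state on restore. Per this diff, restore_predict_progress=False
# pops the "predict_progress" key, and restore_optimizers=False is routed
# through _remove_app_state_keys(remove_optimizers=True).
options = RestoreOptions(
    restore_predict_progress=False,
    restore_optimizers=False,
)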
