Skip to content

Commit 86e11f3

Browse files
diego-urgell authored and facebook-github-bot committed
Conditionally include also train_dataloader in fit-eval checkpoints (#918)
Summary: Pull Request resolved: #918. Reviewed By: JKSenthil. Differential Revision: D63919225. fbshipit-source-id: 515eb69568ad508ca915253f35565e8bd469c477
1 parent 84d1be9 commit 86e11f3

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

tests/framework/callbacks/test_checkpoint_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torchtnt.framework._test_utils import (
1717
DummyAutoUnit,
1818
DummyEvalUnit,
19+
DummyFitUnit,
1920
DummyMeanMetric,
2021
DummyTrainUnit,
2122
generate_dummy_stateful_dataloader,
@@ -85,6 +86,33 @@ def test_get_app_state(self) -> None:
8586
],
8687
)
8788

89+
# Test evaluate intra-epoch within train epoch on FIT (evaluate_every_n_steps)
90+
my_unit = DummyFitUnit(input_dim=2)
91+
my_unit.train_progress.increment_step() # Simulate at least one step in each phase
92+
my_unit.eval_progress.increment_step()
93+
94+
state = get_dummy_fit_state()
95+
state._active_phase = ActivePhase.EVALUATE
96+
97+
train_dl = generate_dummy_stateful_dataloader(1, 1, 1)
98+
eval_dl = generate_dummy_stateful_dataloader(1, 1, 1)
99+
none_throws(state.train_state)._dataloader = train_dl
100+
none_throws(state.eval_state)._dataloader = eval_dl
101+
102+
app_state = _prepare_app_state_for_checkpoint(state, my_unit, intra_epoch=True)
103+
self.assertCountEqual(
104+
app_state.keys(),
105+
[
106+
"module",
107+
"optimizer",
108+
"loss_fn",
109+
"train_progress",
110+
"eval_progress",
111+
"train_dataloader",
112+
"eval_dataloader",
113+
],
114+
)
115+
88116
def test_get_step_phase_mapping(self) -> None:
89117
unit = DummyAutoUnit(module=nn.Linear(2, 2))
90118
unit.train_progress._num_steps_completed = 5

torchtnt/framework/callbacks/_checkpoint_utils.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
from typing import Any, cast, Dict, Union
1111

12+
from pyre_extensions import none_throws
13+
1214
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
13-
from torchtnt.framework.state import EntryPoint, State
15+
from torchtnt.framework.state import ActivePhase, EntryPoint, State
1416
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
1517
from torchtnt.utils.checkpoint import Phase
1618

@@ -123,13 +125,27 @@ def _prepare_app_state_for_checkpoint(
123125
remove_lr_schedulers=True,
124126
)
125127

128+
if not intra_epoch:
129+
return app_state
130+
126131
# for intra-epoch checkpointing, include dataloader state of the current phase
127-
phase_dl = state.active_phase_state().dataloader
128-
if intra_epoch and isinstance(phase_dl, Stateful):
129-
dataloader_state_key = _PHASE_DL_STATE_KEY_MAPPING[
130-
state.active_phase.into_phase()
131-
]
132-
app_state[dataloader_state_key] = phase_dl
132+
active_dataloaders = {state.active_phase: state.active_phase_state().dataloader}
133+
134+
# Special case for FIT where eval is executed every n steps. We also need to save
135+
# the train dataloader state. In this case, train epoch wouldn't be incremented yet.
136+
if (
137+
state.entry_point == EntryPoint.FIT
138+
and state.active_phase == ActivePhase.EVALUATE
139+
and cast(TTrainUnit, unit).train_progress.num_steps_completed_in_epoch != 0
140+
):
141+
active_dataloaders[ActivePhase.TRAIN] = none_throws(
142+
state.train_state
143+
).dataloader
144+
145+
for active_phase, dl in active_dataloaders.items():
146+
if isinstance(dl, Stateful):
147+
dl_key = _PHASE_DL_STATE_KEY_MAPPING[active_phase.into_phase()]
148+
app_state[dl_key] = dl
133149

134150
return app_state
135151

0 commit comments

Comments (0)