Skip to content

Commit 47fbf01

Browse files
diego-urgell authored and facebook-github-bot committed
Include eval/predict dataloaders in save path (#909)
Summary: Pull Request resolved: #909 Reviewed By: JKSenthil Differential Revision: D63013007 fbshipit-source-id: b2972310ddd39bb91b64dbdcebc06c5cbd1f3035
1 parent 1064d10 commit 47fbf01

File tree

5 files changed

+115
-9
lines changed

5 files changed

+115
-9
lines changed

tests/framework/callbacks/test_checkpoint_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88

99
import unittest
1010

11+
from pyre_extensions import none_throws
12+
1113
from torch import nn
14+
from torchtnt.framework import ActivePhase
1215

1316
from torchtnt.framework._test_utils import (
1417
DummyAutoUnit,
18+
DummyMeanMetric,
1519
DummyTrainUnit,
20+
generate_dummy_stateful_dataloader,
1621
get_dummy_eval_state,
1722
get_dummy_fit_state,
1823
get_dummy_train_state,
@@ -28,15 +33,37 @@
2833
class CheckpointUtilsTest(unittest.TestCase):
2934

3035
def test_get_app_state(self) -> None:
    """Check the keys returned by ``_prepare_app_state_for_checkpoint``.

    Two scenarios are exercised: an end-of-epoch checkpoint
    (``intra_epoch=False``) and a train intra-epoch checkpoint
    (``intra_epoch=True``) taken while the train phase holds a stateful
    dataloader and the unit carries an extra ``mean_metric`` attribute.
    """
    # End-of-epoch checkpoint: only the unit's own state plus train progress.
    unit = DummyTrainUnit(input_dim=2)
    dummy_state = get_dummy_train_state()
    end_of_epoch_app_state = _prepare_app_state_for_checkpoint(
        dummy_state, unit, intra_epoch=False
    )
    self.assertCountEqual(
        end_of_epoch_app_state.keys(),
        ["module", "optimizer", "loss_fn", "train_progress"],
    )

    # Train intra-epoch checkpoint: the stateful train dataloader and the
    # metric attached to the unit are expected among the keys as well.
    unit = DummyTrainUnit(input_dim=2)
    unit.mean_metric = DummyMeanMetric()  # pyre-ignore[16]
    dummy_state = get_dummy_train_state()
    dummy_state._active_phase = ActivePhase.TRAIN
    none_throws(
        dummy_state.train_state
    )._dataloader = generate_dummy_stateful_dataloader(1, 1, 1)

    intra_epoch_app_state = _prepare_app_state_for_checkpoint(
        dummy_state, unit, intra_epoch=True
    )
    expected_keys = [
        "module",
        "optimizer",
        "loss_fn",
        "train_progress",
        "train_dataloader",
        "mean_metric",
    ]
    self.assertCountEqual(intra_epoch_app_state.keys(), expected_keys)
66+
4067
def test_get_step_phase_mapping(self) -> None:
4168
unit = DummyAutoUnit(module=nn.Linear(2, 2))
4269
unit.train_progress._num_steps_completed = 5

tests/framework/test_state.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99

1010
import unittest
1111

12+
from torchtnt.framework import ActivePhase
13+
1214
from torchtnt.framework.state import _check_loop_condition, PhaseState
15+
from torchtnt.utils.checkpoint import Phase
1316

1417

1518
class StateTest(unittest.TestCase):
@@ -39,3 +42,13 @@ def test_phase_state_validation(self) -> None:
3942
ValueError, "Invalid value provided for evaluate_every_n_epochs"
4043
):
4144
PhaseState(dataloader=[], evaluate_every_n_epochs=-2)
45+
46+
def test_active_phase_into_phase(self) -> None:
    """Check that each ``ActivePhase`` member converts to the same-named ``Phase``."""
    expected_conversions = (
        (ActivePhase.TRAIN, Phase.TRAIN),
        (ActivePhase.EVALUATE, Phase.EVALUATE),
        (ActivePhase.PREDICT, Phase.PREDICT),
    )
    for active_phase, expected_phase in expected_conversions:
        self.assertEqual(active_phase.into_phase(), expected_phase)

torchtnt/framework/_test_utils.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
# pyre-strict
99

10-
from typing import Iterable, Iterator, List, Optional, Tuple
10+
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
1111

1212
import torch
1313
from torch import nn, Tensor
@@ -236,3 +236,51 @@ def configure_optimizers_and_lr_scheduler(
236236
my_optimizer, gamma=0.9
237237
)
238238
return my_optimizer, my_lr_scheduler
239+
240+
241+
class DummyStatefulDataLoader:
    """Dummy Dataloader that implements state_dict and load_state_dict.

    Iteration is delegated to the wrapped dataloader. The reported state is
    a fixed placeholder, and loading a state is a no-op.
    """

    def __init__(self, dataloader: DataLoader) -> None:
        """Store the dataloader that iteration is delegated to."""
        self.dataloader = dataloader

    def __iter__(self) -> Iterator[object]:
        """Return a fresh iterator over the wrapped dataloader."""
        return iter(self.dataloader)

    def state_dict(self) -> Dict[str, Any]:
        """Return a constant placeholder state."""
        placeholder_state: Dict[str, Any] = {"current_batch": 1}
        return placeholder_state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Accept a state dict and ignore it; nothing is restored."""
        return None
255+
256+
257+
def generate_dummy_stateful_dataloader(
    num_samples: int, input_dim: int, batch_size: int
) -> DummyStatefulDataLoader:
    """Build a ``DummyStatefulDataLoader`` around a random iterable dataset.

    Args:
        num_samples: forwarded to ``RandomIterableDataset`` as its second argument.
        input_dim: forwarded to ``RandomIterableDataset`` as its first argument.
        batch_size: batch size of the wrapped ``DataLoader``.

    Returns:
        A ``DummyStatefulDataLoader`` wrapping the newly created ``DataLoader``.
    """
    random_dataset = RandomIterableDataset(input_dim, num_samples)
    wrapped_loader = DataLoader(dataset=random_dataset, batch_size=batch_size)
    return DummyStatefulDataLoader(wrapped_loader)
266+
267+
268+
class DummyMeanMetric:
    """Minimal running-mean metric exposing state_dict and load_state_dict."""

    def __init__(self) -> None:
        """Start with an empty accumulator."""
        super().__init__()
        self.sum: float = 0.0
        self.count: int = 0

    def update(self, value: float) -> None:
        """Add ``value`` to the running sum and count one more sample."""
        self.sum = self.sum + value
        self.count = self.count + 1

    def compute(self) -> float:
        """Return the mean of the recorded values, or 0.0 when none were recorded."""
        if self.count > 0:
            return self.sum / self.count
        return 0.0

    def state_dict(self) -> Dict[str, Any]:
        """Return the accumulator as a dict with keys ``sum`` and ``count``."""
        snapshot: Dict[str, Any] = {}
        snapshot["sum"] = self.sum
        snapshot["count"] = self.count
        return snapshot

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite the accumulator with the ``sum`` and ``count`` entries."""
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]

torchtnt/framework/callbacks/_checkpoint_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from typing import Any, cast, Dict, Union
1111

12-
from pyre_extensions import none_throws
1312
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
1413
from torchtnt.framework.state import EntryPoint, State
1514
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainUnit
@@ -19,8 +18,13 @@
1918

2019

2120
# Keys under which progress and dataloader state are stored when checkpointing.
_TRAIN_PROGRESS_STATE_KEY = "train_progress"
_EVAL_PROGRESS_STATE_KEY = "eval_progress"
_TRAIN_DL_STATE_KEY = "train_dataloader"

# Per-phase key used for a phase's dataloader state.
_PHASE_DL_STATE_KEY_MAPPING: Dict[Phase, str] = {
    Phase.TRAIN: _TRAIN_DL_STATE_KEY,
    Phase.EVALUATE: "eval_dataloader",
    Phase.PREDICT: "predict_dataloader",
}
2529

2630

@@ -60,11 +64,13 @@ def _prepare_app_state_for_checkpoint(
6064
"""
6165
app_state = _prepare_app_state(unit)
6266

63-
# for intra-epoch checkpointing, include dataloader states
64-
train_state = none_throws(state.train_state)
65-
train_dl = train_state.dataloader
66-
if intra_epoch and isinstance(train_dl, Stateful):
67-
app_state[_TRAIN_DL_STATE_KEY] = train_dl
67+
# for intra-epoch checkpointing, include dataloader state of the current phase
68+
phase_dl = state.active_phase_state().dataloader
69+
if intra_epoch and isinstance(phase_dl, Stateful):
70+
dataloader_state_key = _PHASE_DL_STATE_KEY_MAPPING[
71+
state.active_phase.into_phase()
72+
]
73+
app_state[dataloader_state_key] = phase_dl
6874

6975
return app_state
7076

torchtnt/framework/state.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Generic, Iterable, Optional, TypeVar
1414

1515
from pyre_extensions import none_throws
16+
from torchtnt.utils.checkpoint import Phase
1617

1718
from torchtnt.utils.timer import BoundedTimer, TimerProtocol
1819

@@ -62,6 +63,17 @@ class ActivePhase(Enum):
6263
EVALUATE = auto()
6364
PREDICT = auto()
6465

66+
def into_phase(self) -> Phase:
    """Converts the active phase to the corresponding phase.

    Returns:
        ``Phase.TRAIN``, ``Phase.EVALUATE`` or ``Phase.PREDICT`` for the
        same-named ``ActivePhase`` member.

    Raises:
        AssertionError: if this value matches none of the three members.
    """
    phase_by_active_phase = {
        ActivePhase.TRAIN: Phase.TRAIN,
        ActivePhase.EVALUATE: Phase.EVALUATE,
        ActivePhase.PREDICT: Phase.PREDICT,
    }
    matched_phase = phase_by_active_phase.get(self)
    if matched_phase is None:
        raise AssertionError("Should match an ActivePhase")
    return matched_phase
76+
6577

6678
class PhaseState(Generic[TData, TStepOutput]):
6779
"""State for each phase (train, eval, predict).

0 commit comments

Comments
 (0)