Skip to content

Commit 6d99aae

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Complete utils to get step and epoch according to active_phase (#913)
Summary: Pull Request resolved: #913 Reviewed By: anshulverma, shalgi Differential Revision: D63491481 fbshipit-source-id: 137bf317d7d3da48e3fb681b2115b94ecb0212db
1 parent 764f36e commit 6d99aae

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

tests/framework/callbacks/test_checkpoint_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
generate_dummy_stateful_dataloader,
2222
get_dummy_eval_state,
2323
get_dummy_fit_state,
24+
get_dummy_predict_state,
2425
get_dummy_train_state,
2526
)
2627

2728
from torchtnt.framework.callbacks._checkpoint_utils import (
29+
_get_epoch,
2830
_get_step_phase_mapping,
2931
_prepare_app_state_for_checkpoint,
3032
)
@@ -87,6 +89,7 @@ def test_get_step_phase_mapping(self) -> None:
8789
unit = DummyAutoUnit(module=nn.Linear(2, 2))
8890
unit.train_progress._num_steps_completed = 5
8991
unit.eval_progress._num_steps_completed = 7
92+
unit.predict_progress._num_steps_completed = 9
9093

9194
fit_state = get_dummy_fit_state()
9295
self.assertEqual(
@@ -99,3 +102,26 @@ def test_get_step_phase_mapping(self) -> None:
99102

100103
eval_state = get_dummy_eval_state()
101104
self.assertEqual({Phase.EVALUATE: 7}, _get_step_phase_mapping(eval_state, unit))
105+
106+
predict_state = get_dummy_predict_state()
107+
self.assertEqual(
108+
{Phase.PREDICT: 9}, _get_step_phase_mapping(predict_state, unit)
109+
)
110+
111+
def test_get_epoch(self) -> None:
112+
unit = DummyAutoUnit(module=nn.Linear(2, 2))
113+
unit.train_progress._num_epochs_completed = 1
114+
unit.eval_progress._num_epochs_completed = 2
115+
unit.predict_progress._num_epochs_completed = 3
116+
117+
fit_state = get_dummy_fit_state()
118+
self.assertEqual(1, _get_epoch(fit_state, unit))
119+
120+
train_state = get_dummy_train_state()
121+
self.assertEqual(1, _get_epoch(train_state, unit))
122+
123+
eval_state = get_dummy_eval_state()
124+
self.assertEqual(2, _get_epoch(eval_state, unit))
125+
126+
predict_state = get_dummy_predict_state()
127+
self.assertEqual(3, _get_epoch(predict_state, unit))

torchtnt/framework/_test_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,19 @@ def get_dummy_eval_state(dataloader: Optional[Iterable[object]] = None) -> State
4646
)
4747

4848

49+
def get_dummy_predict_state(dataloader: Optional[Iterable[object]] = None) -> State:
50+
return State(
51+
entry_point=EntryPoint.PREDICT,
52+
predict_state=PhaseState(
53+
dataloader=dataloader or [1, 2, 3, 4],
54+
max_epochs=1,
55+
max_steps=1,
56+
max_steps_per_epoch=1,
57+
),
58+
timer=None,
59+
)
60+
61+
4962
def get_dummy_fit_state() -> State:
5063
return State(
5164
entry_point=EntryPoint.FIT,

torchtnt/framework/callbacks/_checkpoint_utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
1313
from torchtnt.framework.state import EntryPoint, State
14-
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainUnit
14+
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
1515
from torchtnt.utils.checkpoint import Phase
1616

1717
from torchtnt.utils.stateful import Stateful
@@ -31,7 +31,7 @@
3131

3232

3333
def _get_step_phase_mapping(
34-
state: State, unit: Union[TTrainUnit, TEvalUnit]
34+
state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit]
3535
) -> Dict[Phase, int]:
3636
"""
3737
Returns a mapping of phase to step, depending on the entrypoint.
@@ -47,9 +47,32 @@ def _get_step_phase_mapping(
4747
eval_unit = cast(TEvalUnit, unit)
4848
step_mapping[Phase.EVALUATE] = eval_unit.eval_progress.num_steps_completed
4949

50+
if state.entry_point == EntryPoint.PREDICT:
51+
predict_unit = cast(TPredictUnit, unit)
52+
step_mapping[Phase.PREDICT] = predict_unit.predict_progress.num_steps_completed
53+
5054
return step_mapping
5155

5256

57+
def _get_epoch(state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit]) -> int:
58+
"""
59+
Returns the epoch depending on the entrypoint. For FIT, it always returns the train epoch.
60+
"""
61+
if state.entry_point in (EntryPoint.TRAIN, EntryPoint.FIT):
62+
train_unit = cast(TTrainUnit, unit)
63+
return train_unit.train_progress.num_epochs_completed
64+
65+
elif state.entry_point == EntryPoint.PREDICT:
66+
predict_unit = cast(TPredictUnit, unit)
67+
return predict_unit.predict_progress.num_epochs_completed
68+
69+
elif state.entry_point == EntryPoint.EVALUATE:
70+
eval_unit = cast(TEvalUnit, unit)
71+
return eval_unit.eval_progress.num_epochs_completed
72+
73+
raise ValueError(f"Unknown entrypoint: {state.entry_point}")
74+
75+
5376
def _prepare_app_state(unit: AppStateMixin) -> Dict[str, Any]:
5477
"""Join together all of the tracked stateful entities to simplify registration of snapshottable states, deals with FSDP case"""
5578
app_state = unit.app_state()

0 commit comments

Comments
 (0)