Skip to content

Commit 93347d9

Browse files
clarkdykang authored and facebook-github-bot committed
Set evaluate_every_epoch value based on number of train epochs completed (#968)
Summary: Pull Request resolved: #968
- Use `evaluate_every_n_epochs` until a certain number of epochs have been trained, after which evaluation is forced to run every epoch.
- This requires being able to set the `evaluate_every_n_epochs` field of the `PhaseState` class.
- Cutting down on evaluations reduces total time from 8.3 hours to 6.3 hours.
Reviewed By: galrotem
Differential Revision: D68933995
fbshipit-source-id: a16bbcc81482681cd297cbcc6dc01764db8bb097
1 parent d71a41b commit 93347d9

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

tests/framework/test_state.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,3 +52,20 @@ def test_active_phase_into_phase(self) -> None:
5252

5353
predict_phase = ActivePhase.PREDICT
5454
self.assertEqual(predict_phase.into_phase(), Phase.PREDICT)
55+
56+
def test_set_evaluate_every_n_steps_or_epochs(self) -> None:
57+
state = PhaseState(dataloader=[], evaluate_every_n_steps=2)
58+
state.evaluate_every_n_steps = None
59+
state.evaluate_every_n_steps = 100
60+
with self.assertRaisesRegex(
61+
ValueError, "Invalid value provided for evaluate_every_n_steps"
62+
):
63+
state.evaluate_every_n_steps = -2
64+
65+
state = PhaseState(dataloader=[], evaluate_every_n_epochs=2)
66+
state.evaluate_every_n_epochs = None
67+
state.evaluate_every_n_epochs = 100
68+
with self.assertRaisesRegex(
69+
ValueError, "Invalid value provided for evaluate_every_n_epochs"
70+
):
71+
state.evaluate_every_n_epochs = -2

torchtnt/framework/state.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,11 +133,21 @@ def evaluate_every_n_steps(self) -> Optional[int]:
133133
"""Frequency with which to evaluate in terms of training steps, when running :func:`~torchtnt.framework.fit`. Defined by the user."""
134134
return self._evaluate_every_n_steps
135135

136+
@evaluate_every_n_steps.setter
137+
def evaluate_every_n_steps(self, value: Optional[int]) -> None:
138+
_check_loop_condition("evaluate_every_n_steps", value)
139+
self._evaluate_every_n_steps = value
140+
136141
@property
137142
def evaluate_every_n_epochs(self) -> Optional[int]:
138143
"""Frequency with which to evaluate in terms of training epochs, when running :func:`~torchtnt.framework.fit`. Defined by the user."""
139144
return self._evaluate_every_n_epochs
140145

146+
@evaluate_every_n_epochs.setter
147+
def evaluate_every_n_epochs(self, value: Optional[int]) -> None:
148+
_check_loop_condition("evaluate_every_n_epochs", value)
149+
self._evaluate_every_n_epochs = value
150+
141151
@property
142152
def step_output(self) -> Optional[TStepOutput]:
143153
"""Output of the last step."""

0 commit comments

Comments
 (0)