Skip to content

Commit 45e1138

Browse files
diego-urgell authored and facebook-github-bot committed
Early exit on predict entrypoint if epoch has completed (#972)
Summary: Pull Request resolved: #972
Reviewed By: galrotem, anshulverma
Differential Revision: D69865409
fbshipit-source-id: 814dd7081ed7489c63eb021c1bd1ff3abaeb47c7
1 parent 2e6cd59 commit 45e1138

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

tests/framework/test_predict.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torchtnt.framework.predict import predict
2222
from torchtnt.framework.state import State
2323
from torchtnt.framework.unit import PredictUnit, TPredictUnit
24+
from torchtnt.utils.progress import Progress
2425
from torchtnt.utils.timer import Timer
2526

2627

@@ -242,6 +243,16 @@ def test_predict_ckpt_autograd_mode(
242243
predict(unit, dataloader, callbacks=cast(List[Callback], callbacks))
243244
mock_autograd_mode.assert_called_once()
244245

246+
def test_predict_epoch_check(self) -> None:
247+
unit = MagicMock(wraps=DummyPredictUnit(2))
248+
unit.predict_progress = Progress(num_epochs_completed=1, num_steps_completed=5)
249+
250+
dataloader = generate_random_dataloader(10, 2, 2)
251+
252+
predict(unit, dataloader, max_steps_per_epoch=100)
253+
254+
unit.on_predict_start.assert_not_called()
255+
245256

246257
Batch = Tuple[torch.Tensor, torch.Tensor]
247258

torchtnt/framework/predict.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ def _predict_impl(
122122
# input validation
123123
predict_state = none_throws(state.predict_state)
124124

125+
if predict_unit.predict_progress.num_epochs_completed >= 1:
126+
logger.warning(
127+
"Predict epoch has already been completed. Skipping to avoid duplicate outputs."
128+
)
129+
return
130+
125131
state._active_phase = ActivePhase.PREDICT
126132
logger.info(
127133
f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}"

0 commit comments

Comments
 (0)