Skip to content

Commit 7bfdee4

Browse files
diego-urgell authored and facebook-github-bot committed
Use no_grad instead of inference_mode for predict with checkpointing (#912)
Summary: Pull Request resolved: #912 Reviewed By: mvsfb, JKSenthil Differential Revision: D63491475 fbshipit-source-id: f1772ef607a348ba6e65a5761d746b57b9ba814c
1 parent 1f06115 commit 7bfdee4

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

tests/framework/test_predict.py

Lines changed: 21 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -8,14 +8,15 @@
88
# pyre-strict
99

1010
import unittest
11-
from typing import Any, Iterator, Mapping, Tuple
12-
from unittest.mock import MagicMock
11+
from typing import Any, cast, Iterator, List, Mapping, Tuple
12+
from unittest.mock import MagicMock, patch
1313

1414
import torch
1515
from torch import nn
1616

1717
from torchtnt.framework._test_utils import DummyPredictUnit, generate_random_dataloader
1818
from torchtnt.framework.callback import Callback
19+
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
1920

2021
from torchtnt.framework.predict import predict
2122
from torchtnt.framework.state import State
@@ -223,6 +224,24 @@ def test_error_message(self) -> None:
223224
log.output,
224225
)
225226

227+
def test_predict_ckpt_autograd_mode(
228+
self,
229+
) -> None:
230+
"""
231+
Verify that the pytorch autograd mode used depends on having a checkpoint callback in predict.
232+
"""
233+
unit = DummyPredictUnit(2)
234+
dataloader = generate_random_dataloader(10, 2, 2)
235+
dcp_saver = DistributedCheckpointSaver(dirpath="dummy_dirpath")
236+
237+
for callbacks, mock_autograd_mode in [
238+
([], "torch.inference_mode"),
239+
([dcp_saver], "torch.no_grad"),
240+
]:
241+
with patch(mock_autograd_mode) as mock_autograd_mode:
242+
predict(unit, dataloader, callbacks=cast(List[Callback], callbacks))
243+
mock_autograd_mode.assert_called_once()
244+
226245

227246
Batch = Tuple[torch.Tensor, torch.Tensor]
228247

torchtnt/framework/predict.py

Lines changed: 12 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,7 @@
2020
_set_module_training_mode,
2121
)
2222
from torchtnt.framework.callback import Callback
23+
from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
2324
from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
2425
from torchtnt.framework.unit import TPredictData, TPredictUnit
2526
from torchtnt.framework.utils import get_timing_context
@@ -80,7 +81,10 @@ def predict(
8081
call on_predict_end on unit first and then callbacks
8182
"""
8283
_log_api_usage("predict")
83-
callback_handler = CallbackHandler(callbacks or [])
84+
callbacks = callbacks or []
85+
callback_handler = CallbackHandler(callbacks)
86+
checkpoint_cb_exists = any(isinstance(cb, BaseCheckpointer) for cb in callbacks)
87+
8488
state = State(
8589
entry_point=EntryPoint.PREDICT,
8690
predict_state=PhaseState(
@@ -90,7 +94,13 @@ def predict(
9094
timer=timer,
9195
)
9296
try:
93-
_predict_impl(state, predict_unit, callback_handler)
97+
# all_gather using inference_mode with gloo backend is not supported. Since this collective
98+
# is necessary for checkpointing, we need to use torch.no_grad instead.
99+
# TODO: remove this once all_gather is supported in inference_mode.
100+
inference_ctx = torch.no_grad if checkpoint_cb_exists else torch.inference_mode
101+
with inference_ctx():
102+
_predict_impl(state, predict_unit, callback_handler)
103+
94104
logger.info("Finished predict")
95105
if state.timer:
96106
logger.info(get_timer_summary(state.timer))
@@ -104,7 +114,6 @@ def predict(
104114
raise e
105115

106116

107-
@torch.inference_mode()
108117
def _predict_impl(
109118
state: State,
110119
predict_unit: TPredictUnit,

0 commit comments

Comments (0)