Skip to content

Commit edf6f85

Browse files
JKSenthil authored and facebook-github-bot committed
disable detect anomaly if torch compile enabled (#961)
Summary: Pull Request resolved: #961 Reviewed By: diego-urgell Differential Revision: D68336316 fbshipit-source-id: 84e005cfcf8ebea828b64d592dc1b298369fac7d
1 parent 52b5568 commit edf6f85

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

tests/framework/test_auto_unit.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from torchtnt.utils.distributed import spawn_multi_process
4141
from torchtnt.utils.env import init_from_env
4242
from torchtnt.utils.lr_scheduler import TLRScheduler
43-
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy
43+
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
4444
from torchtnt.utils.progress import Progress
4545
from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
4646
from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -741,6 +741,15 @@ def test_enable_prefetch(self) -> None:
741741
_ = auto_unit._get_next_batch(get_dummy_train_state(), iter(data))
742742
self.assertIsNone(auto_unit._phase_to_next_batch[ActivePhase.TRAIN])
743743

744+
def test_detect_anomaly_disabled_with_torch_compile(self) -> None:
745+
auto_unit = DummyAutoUnit(
746+
module=torch.nn.Linear(2, 2),
747+
detect_anomaly=True,
748+
torch_compile_params=TorchCompileParams(),
749+
)
750+
751+
self.assertIsNone(auto_unit.detect_anomaly)
752+
744753

745754
Batch = Tuple[torch.Tensor, torch.Tensor]
746755

torchtnt/framework/auto_unit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ def __init__(
183183
)
184184

185185
self.detect_anomaly = detect_anomaly
186+
if torch_compile_params is not None:
187+
# torch compile is not compatible with detect anomaly
188+
# so we disable detect anomaly if torch compile is enabled
189+
self.detect_anomaly = None
190+
_logger.warning("torch.compile is enabled, so detect_anomaly is disabled")
186191

187192
# create autocast context based on precision and device type
188193
self.maybe_autocast_precision = torch.autocast(

0 commit comments

Comments (0)