2 files changed: +15 −1

@@ -40,7 +40,7 @@
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy
+from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
 from torchtnt.utils.progress import Progress
 from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
 from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -741,6 +741,15 @@ def test_enable_prefetch(self) -> None:
         _ = auto_unit._get_next_batch(get_dummy_train_state(), iter(data))
         self.assertIsNone(auto_unit._phase_to_next_batch[ActivePhase.TRAIN])
 
+    def test_detect_anomaly_disabled_with_torch_compile(self) -> None:
+        auto_unit = DummyAutoUnit(
+            module=torch.nn.Linear(2, 2),
+            detect_anomaly=True,
+            torch_compile_params=TorchCompileParams(),
+        )
+
+        self.assertIsNone(auto_unit.detect_anomaly)
+
 
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
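The new test pins down the constructor behavior: passing both detect_anomaly=True and torch_compile_params leaves detect_anomaly as None rather than False. A minimal standalone sketch of the same guard logic (resolve_detect_anomaly is an illustrative helper, not part of the TorchTNT API):

import logging
from typing import Any, Optional

_logger = logging.getLogger(__name__)

def resolve_detect_anomaly(
    detect_anomaly: Optional[bool],
    torch_compile_params: Optional[Any],
) -> Optional[bool]:
    # Mirror the AutoUnit guard: if compile params are provided,
    # drop the anomaly-detection flag and warn, rather than
    # silently keeping an option that cannot take effect.
    if torch_compile_params is not None:
        _logger.warning("torch.compile is enabled, so detect_anomaly is disabled")
        return None
    return detect_anomaly

assert resolve_detect_anomaly(True, object()) is None  # compile params set -> cleared
assert resolve_detect_anomaly(True, None) is True      # no compile -> flag preserved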
@@ -183,6 +183,11 @@ def __init__(
         )
 
         self.detect_anomaly = detect_anomaly
+        if torch_compile_params is not None:
+            # torch.compile is not compatible with detect_anomaly,
+            # so disable detect_anomaly when torch.compile is enabled
+            self.detect_anomaly = None
+            _logger.warning("torch.compile is enabled, so detect_anomaly is disabled")
 
         # create autocast context based on precision and device type
         self.maybe_autocast_precision = torch.autocast(
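For context on the incompatibility: eager-mode anomaly detection (torch.autograd.detect_anomaly) works by instrumenting each autograd node so a failing backward op can be traced to the forward op that produced it, whereas a torch.compile'd module executes a fused graph instead of per-op nodes, which is presumably why the two features are treated as incompatible here. A small standalone sketch of the eager usage (plain PyTorch, not TorchTNT code):

import torch

model = torch.nn.Linear(2, 2)
inputs = torch.randn(4, 2)

# detect_anomaly records a traceback per autograd op during the
# forward pass and re-raises backward errors (e.g. NaNs) pointing
# at the forward op that produced them.
with torch.autograd.detect_anomaly():
    loss = model(inputs).sum()
    loss.backward()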