Skip to content

Commit 3dafb0b

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
remove <= 2.0 version checks
Summary: # Context Part of effort to make TorchTNT OSS compliant. TorchTNT uses latest pytorch apis. Sometimes these are only available in nightlies. When that's the case, users in OSS using a stable pytorch release will see import errors when they use TorchTNT. This can be guarded against by importing conditionally, only when the compatible pytorch version is detected in the environment. We have plenty of these checks already. However, at the same time we don't want to bload TNT with too many of these version checks everywhere. Currently we have various version checks for Pytorch 1.0. I propose to make Pytorch 2.0+ a hard dependency for TorchTNT going forward. This will * remove existing version check bloat in TNT * force users to use latest features from pytorch And Pytorch 2.0 is documented to be 100% backwards compatible, so no bugs or errors should show up # This Diff Removes all the Pytorch version checks below 2.0 in various places of the codebase, and all the helper functions which check for version Reviewed By: galrotem Differential Revision: D56446353 fbshipit-source-id: 2a594e21fa755f249f0f40e352aa90e4476d83ca
1 parent ad3ff86 commit 3dafb0b

17 files changed

+44
-293
lines changed

tests/framework/callbacks/test_module_summary.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919

2020
from torchtnt.framework.callbacks.module_summary import ModuleSummary
2121
from torchtnt.framework.state import EntryPoint, PhaseState, State
22-
from torchtnt.utils.version import is_torch_version_geq_1_13
23-
24-
MODULE_SUMMARY_FLOPS_AVAILABLE = False
25-
if is_torch_version_geq_1_13():
26-
MODULE_SUMMARY_FLOPS_AVAILABLE = True
2722

2823

2924
class ModuleSummaryTest(unittest.TestCase):
@@ -85,10 +80,6 @@ def forward(self, x):
8580
self.assertTrue("b1" in ms.submodule_summaries)
8681
self.assertTrue("l2" in ms.submodule_summaries)
8782

88-
@unittest.skipUnless(
89-
condition=MODULE_SUMMARY_FLOPS_AVAILABLE,
90-
reason="This test needs PyTorch 1.13 or greater to run.",
91-
)
9283
def test_module_summary_retrieve_module_summaries_module_inputs(self) -> None:
9384
"""
9485
Test ModuleSummary callback in train

tests/framework/test_auto_unit.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@
1212
from unittest.mock import MagicMock, patch
1313

1414
import torch
15-
from torchtnt.framework.auto_unit import TrainStepResults
16-
from torchtnt.utils.test_utils import skip_if_not_distributed
17-
18-
from torchtnt.utils.version import is_torch_version_geq_1_13
19-
20-
COMPILE_AVAIL = False
21-
if is_torch_version_geq_1_13():
22-
COMPILE_AVAIL = True
23-
import torch._dynamo
2415

2516
from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
2617

@@ -37,6 +28,7 @@
3728
AutoUnit,
3829
SWALRParams,
3930
SWAParams,
31+
TrainStepResults,
4032
)
4133
from torchtnt.framework.evaluate import evaluate
4234
from torchtnt.framework.predict import predict
@@ -49,6 +41,7 @@
4941
from torchtnt.utils.lr_scheduler import TLRScheduler
5042
from torchtnt.utils.prepare_module import DDPStrategy
5143
from torchtnt.utils.progress import Progress
44+
from torchtnt.utils.test_utils import skip_if_not_distributed
5245
from torchtnt.utils.timer import Timer
5346

5447
TParams = ParamSpec("TParams")

tests/framework/test_auto_unit_gpu.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,16 @@
88
# pyre-strict
99

1010
import unittest
11+
12+
from copy import deepcopy
1113
from typing import TypeVar
1214
from unittest.mock import MagicMock, patch
1315

1416
import torch
15-
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
16-
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
17-
18-
from torchtnt.utils.version import is_torch_version_geq_1_13
19-
20-
COMPILE_AVAIL = False
21-
if is_torch_version_geq_1_13():
22-
COMPILE_AVAIL = True
23-
import torch._dynamo
24-
25-
from copy import deepcopy
2617

2718
from pyre_extensions import ParameterSpecification as ParamSpec
2819
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
20+
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
2921
from torchtnt.framework._test_utils import (
3022
DummyAutoUnit,
3123
generate_random_dataloader,
@@ -40,6 +32,7 @@
4032
from torchtnt.utils.distributed import spawn_multi_process
4133
from torchtnt.utils.env import init_from_env, seed
4234
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
35+
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
4336

4437
TParams = ParamSpec("TParams")
4538
T = TypeVar("T")
@@ -320,10 +313,6 @@ def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
320313
device_type="cuda", dtype=torch.float16, enabled=True
321314
)
322315

323-
@unittest.skipUnless(
324-
condition=COMPILE_AVAIL,
325-
reason="This test needs PyTorch 1.13 or greater to run.",
326-
)
327316
@skip_if_not_gpu
328317
@patch("torch.compile")
329318
def test_compile_predict(self, mock_dynamo: MagicMock) -> None:

tests/utils/test_memory_snapshot_profiler.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,9 @@
1414
MemorySnapshotParams,
1515
MemorySnapshotProfiler,
1616
)
17-
from torchtnt.utils.version import is_torch_version_geq_2_0
1817

1918

2019
class MemorySnapshotProfilerTest(unittest.TestCase):
21-
22-
torch_version_geq_2_0: bool = is_torch_version_geq_2_0()
23-
24-
@unittest.skipUnless(
25-
condition=torch_version_geq_2_0,
26-
reason="This test needs changes from PyTorch 2.0 to run.",
27-
)
2820
def test_validation(self) -> None:
2921
"""Test parameter validation."""
3022
with tempfile.TemporaryDirectory() as temp_dir:

tests/utils/test_memory_snapshot_profiler_gpu.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,10 @@
1818
MemorySnapshotProfiler,
1919
)
2020
from torchtnt.utils.test_utils import skip_if_not_gpu
21-
from torchtnt.utils.version import is_torch_version_geq_2_0
2221

2322

2423
class MemorySnapshotProfilerGPUTest(unittest.TestCase):
25-
26-
torch_version_geq_2_0: bool = is_torch_version_geq_2_0()
27-
2824
@skip_if_not_gpu
29-
@unittest.skipUnless(
30-
condition=torch_version_geq_2_0,
31-
reason="This test needs changes from PyTorch 2.0 to run.",
32-
)
3325
def test_stop_step(self) -> None:
3426
"""Test that a memory snapshot is saved when stop_step is reached."""
3527
with tempfile.TemporaryDirectory() as temp_dir:

tests/utils/test_oom_gpu.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,10 @@
1616
from torchtnt.utils.oom import log_memory_snapshot
1717

1818
from torchtnt.utils.test_utils import skip_if_not_gpu
19-
from torchtnt.utils.version import is_torch_version_geq_2_0
2019

2120

2221
class OomGPUTest(unittest.TestCase):
2322
@skip_if_not_gpu
24-
@unittest.skipUnless(
25-
condition=bool(is_torch_version_geq_2_0()),
26-
reason="This test needs changes from PyTorch 2.0 to run.",
27-
)
2823
def test_log_memory_snapshot(self) -> None:
2924
with tempfile.TemporaryDirectory() as temp_dir:
3025
# Record history

tests/utils/test_prepare_module.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,7 @@
2222
TorchCompileParams,
2323
)
2424
from torchtnt.utils.test_utils import skip_if_not_distributed
25-
from torchtnt.utils.version import is_torch_version_geq_1_13, Version
26-
27-
COMPILE_AVAIL = False
28-
if is_torch_version_geq_1_13():
29-
COMPILE_AVAIL = True
30-
import torch._dynamo
25+
from torchtnt.utils.version import Version
3126

3227

3328
class PrepareModelTest(unittest.TestCase):
@@ -170,10 +165,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
170165
torch_compile_params=TorchCompileParams(backend="inductor"),
171166
)
172167

173-
@unittest.skipUnless(
174-
condition=COMPILE_AVAIL,
175-
reason="This test needs PyTorch 1.13 or greater to run.",
176-
)
177168
def test_prepare_module_compile_invalid_backend(self) -> None:
178169
"""
179170
verify error is thrown on invalid backend
@@ -199,10 +190,6 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
199190
torch_compile_params=TorchCompileParams(),
200191
)
201192

202-
@unittest.skipUnless(
203-
condition=COMPILE_AVAIL,
204-
reason="This test needs PyTorch 1.13 or greater to run.",
205-
)
206193
def test_prepare_module_compile_module_state_dict(self) -> None:
207194
device = init_from_env()
208195
my_module = torch.nn.Linear(2, 2, device=device)

tests/utils/test_prepare_module_gpu.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
# pyre-strict
99
import unittest
10-
from unittest.mock import patch
1110

1211
import torch
12+
13+
from torch.distributed._composable import fully_shard
1314
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1415
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
1516
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -24,15 +25,6 @@
2425
prepare_module,
2526
)
2627
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
27-
from torchtnt.utils.version import is_torch_version_geq_1_13, is_torch_version_geq_2_0
28-
29-
COMPILE_AVAIL = False
30-
if is_torch_version_geq_1_13():
31-
COMPILE_AVAIL = True
32-
import torch._dynamo
33-
34-
if is_torch_version_geq_2_0():
35-
from torch.distributed._composable import fully_shard
3628

3729

3830
class PrepareModelGPUTest(unittest.TestCase):
@@ -81,33 +73,6 @@ def _test_prepare_fsdp() -> None:
8173
tc = unittest.TestCase()
8274
tc.assertTrue(isinstance(fsdp_module, FSDP))
8375

84-
@skip_if_not_distributed
85-
@skip_if_not_gpu
86-
def test_fsdp_pytorch_version(self) -> None:
87-
"""
88-
Test that a RuntimeError is thrown when using FSDP, and PyTorch < v1.12
89-
"""
90-
spawn_multi_process(
91-
2,
92-
"nccl",
93-
self._test_fsdp_pytorch_version,
94-
)
95-
96-
@staticmethod
97-
def _test_fsdp_pytorch_version() -> None:
98-
device = init_from_env()
99-
module = torch.nn.Linear(2, 2).to(device)
100-
101-
tc = unittest.TestCase()
102-
with patch(
103-
"torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
104-
return_value=False,
105-
), tc.assertRaisesRegex(
106-
RuntimeError,
107-
"Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
108-
):
109-
_ = prepare_fsdp(module, device, FSDPStrategy())
110-
11176
@skip_if_not_distributed
11277
@unittest.skipUnless(
11378
condition=bool(torch.cuda.device_count() >= 2),
@@ -128,9 +93,8 @@ def _test_is_fsdp_module() -> None:
12893
model = FSDP(torch.nn.Linear(1, 1, device=device))
12994
assert _is_fsdp_module(model)
13095
model = torch.nn.Linear(1, 1, device=device)
131-
if is_torch_version_geq_2_0():
132-
fully_shard(model)
133-
assert _is_fsdp_module(model)
96+
fully_shard(model)
97+
assert _is_fsdp_module(model)
13498

13599
@skip_if_not_distributed
136100
@skip_if_not_gpu

tests/utils/test_version.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -48,48 +48,5 @@ def test_get_torch_version(self) -> None:
4848
self.assertEqual(version.get_torch_version(), Version("1.12.0"))
4949

5050
def test_torch_version_comparators(self) -> None:
51-
with patch.object(torch, "__version__", "1.7.0"):
52-
self.assertFalse(version.is_torch_version_geq_1_8())
53-
self.assertFalse(version.is_torch_version_geq_1_9())
54-
self.assertFalse(version.is_torch_version_geq_1_10())
55-
self.assertFalse(version.is_torch_version_geq_1_11())
56-
self.assertFalse(version.is_torch_version_geq_1_12())
57-
58-
with patch.object(torch, "__version__", "1.8.0"):
59-
self.assertTrue(version.is_torch_version_geq_1_8())
60-
self.assertFalse(version.is_torch_version_geq_1_9())
61-
self.assertFalse(version.is_torch_version_geq_1_10())
62-
self.assertFalse(version.is_torch_version_geq_1_11())
63-
self.assertFalse(version.is_torch_version_geq_1_12())
64-
65-
with patch.object(torch, "__version__", "1.9.0"):
66-
self.assertTrue(version.is_torch_version_geq_1_8())
67-
self.assertTrue(version.is_torch_version_geq_1_9())
68-
self.assertFalse(version.is_torch_version_geq_1_10())
69-
self.assertFalse(version.is_torch_version_geq_1_11())
70-
self.assertFalse(version.is_torch_version_geq_1_12())
71-
72-
with patch.object(torch, "__version__", "1.10.0"):
73-
self.assertTrue(version.is_torch_version_geq_1_8())
74-
self.assertTrue(version.is_torch_version_geq_1_9())
75-
self.assertTrue(version.is_torch_version_geq_1_10())
76-
self.assertFalse(version.is_torch_version_geq_1_11())
77-
self.assertFalse(version.is_torch_version_geq_1_12())
78-
79-
with patch.object(torch, "__version__", "1.11.0"):
80-
self.assertTrue(version.is_torch_version_geq_1_8())
81-
self.assertTrue(version.is_torch_version_geq_1_9())
82-
self.assertTrue(version.is_torch_version_geq_1_10())
83-
self.assertTrue(version.is_torch_version_geq_1_11())
84-
self.assertFalse(version.is_torch_version_geq_1_12())
85-
86-
with patch.object(torch, "__version__", "1.12.0"):
87-
self.assertTrue(version.is_torch_version_geq_1_8())
88-
self.assertTrue(version.is_torch_version_geq_1_9())
89-
self.assertTrue(version.is_torch_version_geq_1_10())
90-
self.assertTrue(version.is_torch_version_geq_1_11())
91-
self.assertTrue(version.is_torch_version_geq_1_12())
92-
9351
with patch.object(torch, "__version__", "2.0.0a0"):
94-
self.assertTrue(version.is_torch_version_ge_1_13_1())
95-
self.assertFalse(version.is_torch_version_geq_2_0())
52+
self.assertFalse(version.is_torch_version_geq_2_1())

torchtnt/framework/auto_unit.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
TorchCompileParams,
5151
)
5252
from torchtnt.utils.swa import AveragedModel
53-
from torchtnt.utils.version import is_torch_version_ge_1_13_1
5453
from typing_extensions import Literal
5554

5655

@@ -166,8 +165,6 @@ def __init__(
166165
torch_compile_params: Optional[TorchCompileParams] = None,
167166
) -> None:
168167
super().__init__()
169-
if torch_compile_params:
170-
_validate_torch_compile_available()
171168

172169
self.device: torch.device = device or init_from_env()
173170
self.precision: Optional[torch.dtype] = (
@@ -879,11 +876,3 @@ def _update_lr_and_swa(self, state: State, number_of_steps_or_epochs: int) -> No
879876
state, f"{self.__class__.__name__}.lr_scheduler_step"
880877
):
881878
self.step_lr_scheduler()
882-
883-
884-
def _validate_torch_compile_available() -> None:
885-
if not is_torch_version_ge_1_13_1():
886-
raise RuntimeError(
887-
"Torch compile support is available only in PyTorch 2.0 or higher. "
888-
"Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
889-
)

0 commit comments

Comments
 (0)