
Commit fcd8b22

JKSenthil authored and facebook-github-bot committed
add and use generic torch version comparator
Summary:

# Context
Previously we had an individual helper function for each PyTorch version we wanted to check against. Let's remove these in favor of a single universal function.

# This Diff
Passes the version as a string instead and replaces the old uses.

Reviewed By: galrotem

Differential Revision: D56446382

fbshipit-source-id: 0f8b35cc667603394c5286fd35495d7c5ab5265b
1 parent 3dafb0b commit fcd8b22
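
For orientation, a minimal sketch of the call-site change this commit implies. The import path and version string come from the diffs below; the surrounding if/else is illustrative and not part of the commit:

from torchtnt.utils.version import is_torch_version_geq

# Before this diff, callers used a per-version helper such as is_torch_version_geq_2_1().
# After this diff, the cutoff is passed as a string to one generic comparator.
if is_torch_version_geq("2.1.0"):
    ...  # torch >= 2.1.0 code path
else:
    ...  # fallback for older torch releases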

File tree

5 files changed: 15 additions, 17 deletions

tests/utils/test_prepare_module.py
tests/utils/test_version.py
torchtnt/utils/__init__.py
torchtnt/utils/prepare_module.py
torchtnt/utils/version.py


tests/utils/test_prepare_module.py

Lines changed: 8 additions & 10 deletions
@@ -22,10 +22,12 @@
     TorchCompileParams,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed
-from torchtnt.utils.version import Version
+from torchtnt.utils.version import is_torch_version_geq
 
 
 class PrepareModelTest(unittest.TestCase):
+    torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0")
+
     def test_invalid_fsdp_strategy_str_values(self) -> None:
         from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision
 
@@ -144,7 +146,7 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
 
         tc = unittest.TestCase()
         with patch(
-            "torchtnt.utils.version.get_torch_version", return_value=Version("2.0.0")
+            "torchtnt.utils.prepare_module.is_torch_version_geq", return_value=False
         ):
             with tc.assertRaisesRegex(
                 RuntimeError,
@@ -157,14 +159,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
                     torch_compile_params=TorchCompileParams(backend="inductor"),
                 )
 
-        # no error should be thrown on latest pytorch
-        prepare_module(
-            module=torch.nn.Linear(2, 2),
-            device=init_from_env(),
-            strategy=DDPStrategy(static_graph=True),
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
     def test_prepare_module_compile_invalid_backend(self) -> None:
         """
         verify error is thrown on invalid backend
@@ -190,6 +184,10 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
                 torch_compile_params=TorchCompileParams(),
             )
 
+    @unittest.skipUnless(
+        torch_version_geq_2_1_0,
+        reason="Must be on torch 2.1.0+ to run test",
+    )
     def test_prepare_module_compile_module_state_dict(self) -> None:
         device = init_from_env()
         my_module = torch.nn.Linear(2, 2, device=device)
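
One detail worth noting in the hunks above: because prepare_module binds the comparator via from torchtnt.utils.version import is_torch_version_geq, the test patches the name where it is used (torchtnt.utils.prepare_module.is_torch_version_geq), not where it is defined. A minimal standalone sketch of that pattern; the test class and method names here are illustrative, not from this commit:

import unittest
from unittest.mock import patch


class ExampleVersionGateTest(unittest.TestCase):
    def test_old_torch_code_path(self) -> None:
        # Patch the comparator at its import site (the module that calls it);
        # patching torchtnt.utils.version.is_torch_version_geq would leave the
        # name already bound inside prepare_module untouched.
        with patch(
            "torchtnt.utils.prepare_module.is_torch_version_geq",
            return_value=False,
        ):
            ...  # exercise the branch that expects torch < 2.1.0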

tests/utils/test_version.py

Lines changed: 1 addition & 1 deletion
@@ -49,4 +49,4 @@ def test_get_torch_version(self) -> None:
 
     def test_torch_version_comparators(self) -> None:
         with patch.object(torch, "__version__", "2.0.0a0"):
-            self.assertFalse(version.is_torch_version_geq_2_1())
+            self.assertFalse(version.is_torch_version_geq("2.1.0"))

torchtnt/utils/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@
 from .version import (
     get_python_version,
     get_torch_version,
-    is_torch_version_geq_2_1,
+    is_torch_version_geq,
     is_windows,
 )
 
@@ -144,7 +144,7 @@
     "TLRScheduler",
     "get_python_version",
     "get_torch_version",
-    "is_torch_version_geq_2_1",
+    "is_torch_version_geq",
     "is_windows",
     "get_pet_launch_config",
     "spawn_multi_process",

torchtnt/utils/prepare_module.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@
 )
 
 from torchtnt.utils.rank_zero_log import rank_zero_warn
-from torchtnt.utils.version import is_torch_version_geq_2_1
+from torchtnt.utils.version import is_torch_version_geq
 
 
 @dataclass
@@ -318,7 +318,7 @@ def prepare_module(
         if (
             torch_compile_params
             and strategy.static_graph is True
-            and not is_torch_version_geq_2_1()
+            and not is_torch_version_geq("2.1.0")
         ):
             raise RuntimeError(
                 "Torch version >= 2.1.0 required for Torch compile + DDP with static graph"

torchtnt/utils/version.py

Lines changed: 2 additions & 2 deletions
@@ -56,5 +56,5 @@ def get_torch_version() -> Version:
     return pkg_version
 
 
-def is_torch_version_geq_2_1() -> bool:
-    return get_torch_version() >= Version("2.1.0")
+def is_torch_version_geq(version: str) -> bool:
+    return get_torch_version() >= Version(version)
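
get_torch_version and the Version class (presumably packaging.version.Version) are unchanged by this commit and not shown. A minimal self-contained sketch of the same comparison logic, with a hypothetical local parse of torch.__version__ standing in for get_torch_version:

import torch
from packaging.version import Version


def _parse_torch_version(raw: str) -> Version:
    # Hypothetical stand-in for get_torch_version: drop any local build
    # suffix such as "+cu121" before parsing.
    return Version(raw.split("+")[0])


def is_torch_version_geq(version: str) -> bool:
    # Same shape as the new comparator above: parse once, compare against
    # the caller-supplied cutoff string.
    return _parse_torch_version(torch.__version__) >= Version(version)


print(is_torch_version_geq("2.1.0"))  # True on torch 2.1.0 or newer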
