Skip to content

Commit e9edb28

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
fail fast if unproper strategy passed in (#881)
Summary: Pull Request resolved: #881 tsia Reviewed By: diego-urgell Differential Revision: D61241468 fbshipit-source-id: 05e707ca1307961668a5b674e101e9b2165944d1
1 parent 0efa62b commit e9edb28

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

tests/utils/test_prepare_module.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ def test_prepare_module_strategy_invalid_str(self) -> None:
7575
strategy="foo",
7676
)
7777

78+
def test_prepare_module_invalid_strategy(self) -> None:
79+
with self.assertRaisesRegex(ValueError, "Unknown strategy received"):
80+
prepare_module(
81+
module=torch.nn.Linear(2, 2),
82+
device=init_from_env(),
83+
# pyre-ignore: Incompatible parameter type [6] (intentional to test error raised)
84+
strategy={"_strategy_": "DDPStrategy"},
85+
)
86+
7887
def test_prepare_noop(self) -> None:
7988
device = torch.device("cuda") # Suppose init_from_env returns cuda
8089

torchtnt/utils/prepare_module.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,11 @@ def prepare_module(
312312
"""
313313

314314
if strategy:
315+
if not isinstance(strategy, str) and not isinstance(strategy, Strategy):
316+
raise ValueError(
317+
f"Unknown strategy received: {strategy}. Expect either str or Strategy dataclass"
318+
)
319+
315320
if isinstance(strategy, str):
316321
strategy = convert_str_to_strategy(strategy)
317322
if isinstance(strategy, DDPStrategy):

0 commit comments

Comments
 (0)