Skip to content

Commit 7737e13

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
make AveragedModel OSS compatible
Summary: Certain imports from pytorch AveragedModel are not in torch version 2.0.0, so this diff guards against potential import errors at runtime Reviewed By: galrotem Differential Revision: D56534614 fbshipit-source-id: fe6a5b56eabf51c101fa0ebc0f1eb2870df10c97
1 parent 1a47149 commit 7737e13

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

tests/framework/test_auto_unit.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from torchtnt.utils.lr_scheduler import TLRScheduler
4242
from torchtnt.utils.prepare_module import DDPStrategy
4343
from torchtnt.utils.progress import Progress
44+
from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL
4445
from torchtnt.utils.test_utils import skip_if_not_distributed
4546
from torchtnt.utils.timer import Timer
4647

@@ -149,6 +150,9 @@ def test_predict_step(self) -> None:
149150
predict(auto_unit, pred_dataloader, max_steps_per_epoch=1)
150151
mock_predict_step_end.assert_called_once()
151152

153+
@unittest.skipUnless(
154+
_AVERAGED_MODEL_AVAIL, "AveragedModel needed in version of Pytorch"
155+
)
152156
def test_stochastic_weight_averaging_basic(self) -> None:
153157
"""
154158
Basic stochastic weight averaging tests
@@ -182,6 +186,9 @@ def test_stochastic_weight_averaging_basic(self) -> None:
182186
self.assertIn("swa_scheduler", auto_unit2.app_state())
183187
self.assertIn("swa_scheduler", auto_unit2.tracked_lr_schedulers())
184188

189+
@unittest.skipUnless(
190+
_AVERAGED_MODEL_AVAIL, "AveragedModel needed in version of Pytorch"
191+
)
185192
def test_stochastic_weight_averaging_update_freq(self) -> None:
186193
"""
187194
e2e stochastic weight averaging test to ensure that the SWA model is updated at the correct frequency
@@ -295,11 +302,12 @@ def test_auto_unit_ddp(self) -> None:
295302
Launch tests of AutoUnit with DDP strategy
296303
"""
297304

298-
spawn_multi_process(
299-
2,
300-
"gloo",
301-
self._test_stochastic_weight_averaging_with_ddp,
302-
)
305+
if _AVERAGED_MODEL_AVAIL:
306+
spawn_multi_process(
307+
2,
308+
"gloo",
309+
self._test_stochastic_weight_averaging_with_ddp,
310+
)
303311
spawn_multi_process(
304312
2,
305313
"gloo",

tests/utils/test_swa.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313

1414
import torch
1515

16-
from torchtnt.utils.swa import AveragedModel
16+
from torchtnt.utils.swa import _AVERAGED_MODEL_AVAIL, AveragedModel
17+
18+
if not _AVERAGED_MODEL_AVAIL:
19+
raise unittest.SkipTest("Latest Pytorch is required to run SWA tests")
1720

1821

1922
class TestSWA(unittest.TestCase):

torchtnt/utils/swa.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,17 @@
1010

1111
import torch
1212

13-
from torch.optim.swa_utils import (
14-
AveragedModel as PyTorchAveragedModel,
15-
get_ema_multi_avg_fn,
16-
get_swa_multi_avg_fn,
17-
)
13+
_AVERAGED_MODEL_AVAIL: bool = True
14+
15+
try:
16+
from torch.optim.swa_utils import (
17+
AveragedModel as PyTorchAveragedModel,
18+
get_ema_multi_avg_fn,
19+
get_swa_multi_avg_fn,
20+
)
21+
except ImportError:
22+
_AVERAGED_MODEL_AVAIL = False
23+
1824

1925
TSWA_avg_fn = Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
2026
TSWA_multi_avg_fn = Callable[[List[torch.Tensor], List[torch.Tensor], int], None]
@@ -49,6 +55,12 @@ def __init__(
4955
number of updates. The EMA decay will start small and will approach the
5056
specified ema_decay as more updates occur.
5157
"""
58+
if not _AVERAGED_MODEL_AVAIL:
59+
raise ImportError(
60+
"AveragedModel is not available in this version of PyTorch. \
61+
Please install the latest version of PyTorch."
62+
)
63+
5264
# setup averaging method
5365
if averaging_method == "ema":
5466
if ema_decay < 0.0 or ema_decay > 1.0:

0 commit comments

Comments
 (0)