
Commit c2dcee9

diego-urgell authored and facebook-github-bot committed
Move unit_utils GPU test to specific file (#756)
Summary:
Pull Request resolved: #756
Reviewed By: JKSenthil
Differential Revision: D55257789
fbshipit-source-id: 645906442f243fcc872965c80e8d7fcc2e229fbe
1 parent e806af0 · commit c2dcee9
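The test being moved is gated by the skip_if_not_distributed and skip_if_not_gpu decorators, so it only runs on hosts where torch.distributed and CUDA are available; moving it into its own file keeps the base unit-utils tests free of the FSDP/GPU-only imports. As a rough illustration of that kind of gating (hypothetical stand-ins, not torchtnt's actual skip_if_not_gpu / skip_if_not_distributed implementations):

# Hypothetical skip decorators, shown only to illustrate the gating
# the moved test relies on (not torchtnt's implementation).
import unittest

import torch
import torch.distributed as dist

# Skip the decorated test when no CUDA device is visible.
skip_if_no_cuda = unittest.skipUnless(torch.cuda.is_available(), "test requires a GPU")

# Skip the decorated test when torch.distributed is not available in this build.
skip_if_no_dist = unittest.skipUnless(dist.is_available(), "test requires torch.distributed")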

2 files changed: +45, -28 lines


tests/framework/test_unit_utils.py

Lines changed: 0 additions & 28 deletions
@@ -11,17 +11,12 @@
 from typing import Dict, Iterator
 
 import torch
-
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim import Optimizer
 from torchtnt.framework._unit_utils import (
     _find_optimizers_for_module,
     _step_requires_iterator,
 )
 from torchtnt.framework.state import State
-from torchtnt.utils.distributed import spawn_multi_process
-from torchtnt.utils.env import init_from_env
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
 
 
 class UnitUtilsTest(unittest.TestCase):
@@ -55,26 +50,3 @@ def test_find_optimizers_for_module(self) -> None:
         optimizers = _find_optimizers_for_module(module2, opts)
         optim_name, _ = optimizers[0]
         self.assertEqual(optim_name, "optim2")
-
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_find_optimizers_for_FSDP_module(self) -> None:
-        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)
-
-    @staticmethod
-    def _find_optimizers_for_FSDP_module() -> None:
-        device = init_from_env()
-        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
-        module2 = torch.nn.Linear(10, 10)
-        optim1 = torch.optim.Adam(module1.parameters())
-        optim2 = torch.optim.Adagrad(module2.parameters())
-
-        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
-        optim_list = _find_optimizers_for_module(module1, opts)
-        optim_name, _ = optim_list[0]
-
-        tc = unittest.TestCase()
-        tc.assertEqual(optim_name, "optim1")
-        optim_list = _find_optimizers_for_module(module2, opts)
-        optim_name, _ = optim_list[0]
-        tc.assertEqual(optim_name, "optim2")
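For context on what the remaining assertions (and the moved test) exercise: _find_optimizers_for_module takes a module and a Dict[str, Optimizer] and returns the (name, optimizer) pairs whose parameter groups contain that module's parameters. A minimal sketch of that matching logic, assuming identity-based parameter comparison (an illustration only, not the torchtnt implementation):

# Illustrative sketch only -- not torchtnt's implementation.
# Assumes an optimizer is associated with a module when any of the module's
# parameters appears (by identity) in one of the optimizer's param_groups.
from typing import Dict, List, Tuple

import torch
from torch.optim import Optimizer


def find_optimizers_for_module(
    module: torch.nn.Module, optimizers: Dict[str, Optimizer]
) -> List[Tuple[str, Optimizer]]:
    module_param_ids = {id(p) for p in module.parameters()}
    matches: List[Tuple[str, Optimizer]] = []
    for name, optim in optimizers.items():
        optim_param_ids = {
            id(p) for group in optim.param_groups for p in group["params"]
        }
        if module_param_ids & optim_param_ids:
            matches.append((name, optim))
    return matches

In the moved FSDP test, optim1 is built from the wrapped module1's parameters, so an identity check of this kind still resolves module1 to "optim1" even after FSDP has replaced the original parameters with flattened ones.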
New file (GPU test module)

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import unittest
+from typing import Dict
+
+import torch
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.optim import Optimizer
+from torchtnt.framework._unit_utils import _find_optimizers_for_module
+from torchtnt.utils.distributed import spawn_multi_process
+from torchtnt.utils.env import init_from_env
+from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+
+
+class UnitUtilsGPUTest(unittest.TestCase):
+    @skip_if_not_distributed
+    @skip_if_not_gpu
+    def test_find_optimizers_for_FSDP_module(self) -> None:
+        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)
+
+    @staticmethod
+    def _find_optimizers_for_FSDP_module() -> None:
+        device = init_from_env()
+        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
+        module2 = torch.nn.Linear(10, 10)
+        optim1 = torch.optim.Adam(module1.parameters())
+        optim2 = torch.optim.Adagrad(module2.parameters())
+
+        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
+        optim_list = _find_optimizers_for_module(module1, opts)
+        optim_name, _ = optim_list[0]
+
+        tc = unittest.TestCase()
+        tc.assertEqual(optim_name, "optim1")
+        optim_list = _find_optimizers_for_module(module2, opts)
+        optim_name, _ = optim_list[0]
+        tc.assertEqual(optim_name, "optim2")
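spawn_multi_process(2, "nccl", ...) runs the static test body in two worker processes joined into a NCCL process group, which is why the test carries both skip decorators. A rough sketch of that launch pattern written directly against torch.multiprocessing and torch.distributed (the rendezvous address/port and per-rank device binding below are assumptions for the sketch; torchtnt's helper manages these itself):

# Illustration of the multi-process launch pattern behind a helper like
# spawn_multi_process; rendezvous values here are assumed defaults for a
# single-host run, not torchtnt's configuration.
import os
from typing import Callable

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int, fn: Callable[[], None]) -> None:
    # Minimal single-host rendezvous via environment variables.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # One NCCL rank per GPU: bind this worker to its own device.
    torch.cuda.set_device(rank)
    try:
        fn()
    finally:
        dist.destroy_process_group()


def run_in_subprocesses(world_size: int, fn: Callable[[], None]) -> None:
    # Spawn `world_size` processes; each receives its rank as the first argument.
    mp.spawn(_worker, args=(world_size, fn), nprocs=world_size, join=True)

In the real test, init_from_env plays a similar per-worker setup role, returning the device each rank should use.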
