 from typing import Dict, Iterator

 import torch
-
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim import Optimizer
 from torchtnt.framework._unit_utils import (
     _find_optimizers_for_module,
     _step_requires_iterator,
 )
 from torchtnt.framework.state import State
-from torchtnt.utils.distributed import spawn_multi_process
-from torchtnt.utils.env import init_from_env
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


 class UnitUtilsTest(unittest.TestCase):
@@ -55,26 +50,3 @@ def test_find_optimizers_for_module(self) -> None:
         optimizers = _find_optimizers_for_module(module2, opts)
         optim_name, _ = optimizers[0]
         self.assertEqual(optim_name, "optim2")
-
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_find_optimizers_for_FSDP_module(self) -> None:
-        spawn_multi_process(2, "nccl", self._find_optimizers_for_FSDP_module)
-
-    @staticmethod
-    def _find_optimizers_for_FSDP_module() -> None:
-        device = init_from_env()
-        module1 = FSDP(torch.nn.Linear(10, 10).to(device))
-        module2 = torch.nn.Linear(10, 10)
-        optim1 = torch.optim.Adam(module1.parameters())
-        optim2 = torch.optim.Adagrad(module2.parameters())
-
-        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}
-        optim_list = _find_optimizers_for_module(module1, opts)
-        optim_name, _ = optim_list[0]
-
-        tc = unittest.TestCase()
-        tc.assertEqual(optim_name, "optim1")
-        optim_list = _find_optimizers_for_module(module2, opts)
-        optim_name, _ = optim_list[0]
-        tc.assertEqual(optim_name, "optim2")
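For reference, a minimal self-contained sketch of the same lookup without FSDP, mirroring the removed test with plain nn.Linear modules; the class and test names below are illustrative and not part of this change, and it assumes only that _find_optimizers_for_module returns a list of (name, optimizer) pairs, as the removed test relies on:

# Sketch (not part of the diff): _find_optimizers_for_module matches a module
# to the optimizers that were constructed from its parameters.
import unittest
from typing import Dict

import torch
from torch.optim import Optimizer

from torchtnt.framework._unit_utils import _find_optimizers_for_module


class FindOptimizersSketch(unittest.TestCase):
    def test_find_optimizers_for_module(self) -> None:
        module1 = torch.nn.Linear(10, 10)
        module2 = torch.nn.Linear(10, 10)
        optim1 = torch.optim.Adam(module1.parameters())
        optim2 = torch.optim.Adagrad(module2.parameters())

        opts: Dict[str, Optimizer] = {"optim1": optim1, "optim2": optim2}

        # Each module should map back to the optimizer that owns its parameters.
        optim_name, _ = _find_optimizers_for_module(module1, opts)[0]
        self.assertEqual(optim_name, "optim1")
        optim_name, _ = _find_optimizers_for_module(module2, opts)[0]
        self.assertEqual(optim_name, "optim2")


if __name__ == "__main__":
    unittest.main()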