
Commit 51e1485

diego-urgell authored and facebook-github-bot committed
Split prepare_module GPU tests into separate file (#767)
Summary: Pull Request resolved: #767

Reviewed By: galrotem

Differential Revision: D55452262

fbshipit-source-id: 20b168499edefcefbeb947f5035ba5385b41c725
1 parent 19f2911 commit 51e1485

File tree

2 files changed, +331 -306 lines changed


tests/utils/test_prepare_module.py

Lines changed: 32 additions & 306 deletions
@@ -11,263 +11,26 @@
 from unittest.mock import patch
 
 import torch
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.prepare_module import (
-    _is_fsdp_module,
     DDPStrategy,
     FSDPStrategy,
     NOOPStrategy,
-    prepare_ddp,
-    prepare_fsdp,
     prepare_module,
     TorchCompileParams,
 )
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
-from torchtnt.utils.version import (
-    is_torch_version_geq_1_13,
-    is_torch_version_geq_2_0,
-    Version,
-)
+from torchtnt.utils.test_utils import skip_if_not_distributed
+from torchtnt.utils.version import is_torch_version_geq_1_13, Version
 
 COMPILE_AVAIL = False
 if is_torch_version_geq_1_13():
     COMPILE_AVAIL = True
     import torch._dynamo
 
-if is_torch_version_geq_2_0():
-    from torch.distributed._composable import fully_shard
-
 
 class PrepareModelTest(unittest.TestCase):
-    @skip_if_not_gpu
-    def test_prepare_no_strategy(self) -> None:
-        module = torch.nn.Linear(2, 2)  # initialize on cpu
-        device = init_from_env()  # should be cuda device
-        module = prepare_module(module, device, strategy=None)
-        self.assertEqual(next(module.parameters()).device, device)
-
-    @skip_if_not_gpu
-    def test_prepare_noop(self) -> None:
-        module = torch.nn.Linear(2, 2)  # initialize on cpu
-        device = init_from_env()  # should be cuda device
-        module = prepare_module(module, device, strategy=NOOPStrategy())
-        self.assertNotEqual(next(module.parameters()).device, device)
-
-        module2 = torch.nn.Linear(2, 2)  # initialize on cpu
-        module2 = prepare_module(module2, device, strategy="noop")
-        self.assertNotEqual(next(module2.parameters()).device, device)
-
-    @skip_if_not_gpu
-    @skip_if_not_distributed
-    def test_prepare_ddp(self) -> None:
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_prepare_ddp,
-        )
-
-    @staticmethod
-    def _test_prepare_ddp() -> None:
-        module = torch.nn.Linear(2, 2)
-        device = init_from_env()
-        ddp_module = prepare_ddp(
-            module,
-            device,
-            DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True),
-        )
-        tc = unittest.TestCase()
-        tc.assertTrue(isinstance(ddp_module, DDP))
-
-    @skip_if_not_gpu
-    @skip_if_not_distributed
-    def test_prepare_fsdp(self) -> None:
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_prepare_fsdp,
-        )
-
-    @staticmethod
-    def _test_prepare_fsdp() -> None:
-        module = torch.nn.Linear(2, 2)
-        device = init_from_env()
-        fsdp_module = prepare_fsdp(module, device, FSDPStrategy(limit_all_gathers=True))
-        tc = unittest.TestCase()
-        tc.assertTrue(isinstance(fsdp_module, FSDP))
-
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_fsdp_pytorch_version(self) -> None:
-        """
-        Test that a RuntimeError is thrown when using FSDP, and PyTorch < v1.12
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fsdp_pytorch_version,
-        )
-
-    @staticmethod
-    def _test_fsdp_pytorch_version() -> None:
-        device = init_from_env()
-        module = torch.nn.Linear(2, 2).to(device)
-
-        tc = unittest.TestCase()
-        with patch(
-            "torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
-            return_value=False,
-        ), tc.assertRaisesRegex(
-            RuntimeError,
-            "Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
-        ):
-            _ = prepare_fsdp(module, device, FSDPStrategy())
-
-    @staticmethod
-    def _test_is_fsdp_module() -> None:
-        model = torch.nn.Linear(1, 1)
-        assert not _is_fsdp_module(model)
-        model = FSDP(torch.nn.Linear(1, 1))
-        assert _is_fsdp_module(model)
-        model = torch.nn.Linear(1, 1)
-        if is_torch_version_geq_2_0():
-            fully_shard(model)
-            assert _is_fsdp_module(model)
-
-    @skip_if_not_distributed
-    @unittest.skipUnless(
-        condition=bool(torch.cuda.device_count() >= 2),
-        reason="This test needs 2 GPUs to run.",
-    )
-    def test_is_fsdp_module(self) -> None:
-        spawn_multi_process(
-            2,
-            "gloo",
-            self._test_is_fsdp_module,
-        )
-
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_fdsp_precision(self) -> None:
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fdsp_precision,
-        )
-
-    @staticmethod
-    def _test_fdsp_precision() -> None:
-        module = torch.nn.Linear(1, 1)
-        device = init_from_env()
-        mixed_precision = MixedPrecision(
-            param_dtype=torch.float64,
-        )
-        fsdp_module = prepare_fsdp(
-            module, device, FSDPStrategy(mixed_precision=mixed_precision)
-        )
-        tc = unittest.TestCase()
-        tc.assertTrue(isinstance(fsdp_module, FSDP))
-        tc.assertEqual(
-            fsdp_module.mixed_precision.param_dtype, mixed_precision.param_dtype
-        )
-
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    @unittest.skipUnless(
-        condition=bool(torch.cuda.device_count() >= 2),
-        reason="This test needs 2 GPUs to run.",
-    )
-    def test_fdsp_str_types(self) -> None:
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fdsp_precision_str_types,
-        )
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fdsp_backward_prefetch_str_types,
-        )
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fdsp_sharding_strategy_str_types,
-        )
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fdsp_state_dict_str_types,
-        )
-
-    @staticmethod
-    def _test_fdsp_precision_str_types() -> None:
-        from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision
-
-        module = torch.nn.Linear(1, 1)
-        device = init_from_env()
-        mixed_precision = _MixedPrecision(
-            param_dtype="fp16",
-            reduce_dtype="bf16",
-            buffer_dtype="fp32",
-        )
-
-        fsdp_module = prepare_fsdp(
-            module, device, FSDPStrategy(mixed_precision=mixed_precision)
-        )
-        tc = unittest.TestCase()
-        tc.assertTrue(isinstance(fsdp_module, FSDP))
-
-    @staticmethod
-    def _test_fdsp_backward_prefetch_str_types() -> None:
-        module = torch.nn.Linear(1, 1)
-        device = init_from_env()
-
-        tc = unittest.TestCase()
-        for value in ["BACKWARD_PRE", "BACKWARD_POST"]:
-            fsdp_module = prepare_fsdp(
-                module, device, FSDPStrategy(backward_prefetch=value)
-            )
-            tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")
-
-    @staticmethod
-    def _test_fdsp_sharding_strategy_str_types() -> None:
-        module = torch.nn.Linear(1, 1)
-        device = init_from_env()
-
-        tc = unittest.TestCase()
-        for value in [
-            "FULL_SHARD",
-            "SHARD_GRAD_OP",
-            "NO_SHARD",
-            # skip hybrid strategy; tricky to configure in-test
-        ]:
-
-            fsdp_module = prepare_fsdp(
-                module,
-                device,
-                FSDPStrategy(sharding_strategy=value),
-            )
-            tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")
-
-    @staticmethod
-    def _test_fdsp_state_dict_str_types() -> None:
-        module = torch.nn.Linear(1, 1)
-        device = init_from_env()
-
-        tc = unittest.TestCase()
-        for value in [
-            "FULL_STATE_DICT",
-            "LOCAL_STATE_DICT",
-            "SHARDED_STATE_DICT",
-        ]:
-            fsdp_module = prepare_fsdp(
-                module, device, FSDPStrategy(state_dict_type=value)
-            )
-            tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")
-
     def test_invalid_fsdp_strategy_str_values(self) -> None:
         from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision
 
@@ -315,52 +78,16 @@ def test_prepare_module_strategy_invalid_str(self) -> None:
                 strategy="foo",
             )
 
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_prepare_module_with_fsdp(self) -> None:
-        """
-        Launch tests of FSDP strategy
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_prepare_module_fsdp_strategy_wrapped_in_fsdp,
-        )
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_prepare_module_fsdp_string_wrapped_in_fsdp,
-        )
-
-    @staticmethod
-    def _test_prepare_module_fsdp_strategy_wrapped_in_fsdp() -> None:
-        """
-        Test that the module is correctly wrapped in FSDP
-        """
-
-        fsdp_module = prepare_module(
-            module=torch.nn.Linear(2, 2),
-            device=init_from_env(),
-            strategy=FSDPStrategy(),
-        )
-        tc = unittest.TestCase()
-
-        tc.assertTrue(isinstance(fsdp_module, FSDP))
-
-    @staticmethod
-    def _test_prepare_module_fsdp_string_wrapped_in_fsdp() -> None:
-        """
-        Test that the module is correctly wrapped in FSDP when passing "fsdp" as a string
-        """
+    def test_prepare_noop(self) -> None:
+        device = torch.device("cuda")  # Suppose init_from_env returns cuda
 
-        fsdp_module = prepare_module(
-            module=torch.nn.Linear(2, 2),
-            device=init_from_env(),
-            strategy="fsdp",
-        )
-        tc = unittest.TestCase()
+        module = torch.nn.Linear(2, 2)  # initialize on cpu
+        module = prepare_module(module, device, strategy=NOOPStrategy())
+        self.assertNotEqual(next(module.parameters()).device, device)
 
-        tc.assertTrue(isinstance(fsdp_module, FSDP))
+        module2 = torch.nn.Linear(2, 2)  # initialize on cpu
+        module2 = prepare_module(module2, device, strategy="noop")
+        self.assertNotEqual(next(module2.parameters()).device, device)
 
     @skip_if_not_distributed
     def test_prepare_module_with_ddp(self) -> None:
@@ -443,29 +170,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> None:
             torch_compile_params=TorchCompileParams(backend="inductor"),
         )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @skip_if_not_gpu
-    def test_prepare_module_compile_module_state_dict(self) -> None:
-        device = init_from_env()
-        my_module = torch.nn.Linear(2, 2, device=device)
-        my_module_state_dict = my_module.state_dict()
-        self.assertIsNone(my_module._compiled_call_impl)
-        compiled_module = prepare_module(
-            module=my_module,
-            device=device,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-        compiled_state_dict = compiled_module.state_dict()
-        self.assertCountEqual(compiled_state_dict.keys(), my_module_state_dict.keys())
-        for k in compiled_state_dict.keys():
-            self.assertTrue(
-                torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
-            )
-        self.assertIsNotNone(compiled_module._compiled_call_impl)
-
     @unittest.skipUnless(
         condition=COMPILE_AVAIL,
         reason="This test needs PyTorch 1.13 or greater to run.",
@@ -494,3 +198,25 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
                 strategy=FSDPStrategy(use_orig_params=False),
                 torch_compile_params=TorchCompileParams(),
             )
+
+    @unittest.skipUnless(
+        condition=COMPILE_AVAIL,
+        reason="This test needs PyTorch 1.13 or greater to run.",
+    )
+    def test_prepare_module_compile_module_state_dict(self) -> None:
+        device = init_from_env()
+        my_module = torch.nn.Linear(2, 2, device=device)
+        my_module_state_dict = my_module.state_dict()
+        self.assertIsNone(my_module._compiled_call_impl)
+        compiled_module = prepare_module(
+            module=my_module,
+            device=device,
+            torch_compile_params=TorchCompileParams(backend="inductor"),
+        )
+        compiled_state_dict = compiled_module.state_dict()
+        self.assertCountEqual(compiled_state_dict.keys(), my_module_state_dict.keys())
+        for k in compiled_state_dict.keys():
+            self.assertTrue(
+                torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
+            )
+        self.assertIsNotNone(compiled_module._compiled_call_impl)
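
Note: only the edits to tests/utils/test_prepare_module.py are shown above; the second changed file, the new GPU-only test module, is not rendered on this page. As a rough sketch of where the removed GPU-gated tests land, assuming a destination file such as tests/utils/test_prepare_module_gpu.py and a class name like PrepareModelGPUTest (both assumptions, not confirmed by this view), one of the moved tests would look roughly like this, with the body taken verbatim from the removed lines above:

# Sketch only: the file path and class name are assumptions; the test body is
# copied from the lines removed from tests/utils/test_prepare_module.py above.
import unittest

import torch

from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import prepare_module
from torchtnt.utils.test_utils import skip_if_not_gpu


class PrepareModelGPUTest(unittest.TestCase):
    @skip_if_not_gpu
    def test_prepare_no_strategy(self) -> None:
        module = torch.nn.Linear(2, 2)  # initialize on cpu
        device = init_from_env()  # should be cuda device
        module = prepare_module(module, device, strategy=None)
        self.assertEqual(next(module.parameters()).device, device)


if __name__ == "__main__":
    unittest.main()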
