 11 |  11 | from unittest.mock import patch
 12 |  12 |
 13 |  13 | import torch
 14 |     | -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 15 |     | -from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 16 |  14 | from torch.nn.parallel import DistributedDataParallel as DDP
 17 |  15 | from torchtnt.utils.distributed import spawn_multi_process
 18 |  16 | from torchtnt.utils.env import init_from_env
 19 |  17 | from torchtnt.utils.prepare_module import (
 20 |     | -    _is_fsdp_module,
 21 |  18 |     DDPStrategy,
 22 |  19 |     FSDPStrategy,
 23 |  20 |     NOOPStrategy,
 24 |     | -    prepare_ddp,
 25 |     | -    prepare_fsdp,
 26 |  21 |     prepare_module,
 27 |  22 |     TorchCompileParams,
 28 |  23 | )
 29 |     | -from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
 30 |     | -from torchtnt.utils.version import (
 31 |     | -    is_torch_version_geq_1_13,
 32 |     | -    is_torch_version_geq_2_0,
 33 |     | -    Version,
 34 |     | -)
    |  24 | +from torchtnt.utils.test_utils import skip_if_not_distributed
    |  25 | +from torchtnt.utils.version import is_torch_version_geq_1_13, Version
 35 |  26 |
 36 |  27 | COMPILE_AVAIL = False
 37 |  28 | if is_torch_version_geq_1_13():
 38 |  29 |     COMPILE_AVAIL = True
 39 |  30 |     import torch._dynamo
 40 |  31 |
 41 |     | -if is_torch_version_geq_2_0():
 42 |     | -    from torch.distributed._composable import fully_shard
 43 |     | -
 44 |  32 |
 45 |  33 | class PrepareModelTest(unittest.TestCase):
 46 |     | -    @skip_if_not_gpu
 47 |     | -    def test_prepare_no_strategy(self) -> None:
 48 |     | -        module = torch.nn.Linear(2, 2)  # initialize on cpu
 49 |     | -        device = init_from_env()  # should be cuda device
 50 |     | -        module = prepare_module(module, device, strategy=None)
 51 |     | -        self.assertEqual(next(module.parameters()).device, device)
 52 |     | -
 53 |     | -    @skip_if_not_gpu
 54 |     | -    def test_prepare_noop(self) -> None:
 55 |     | -        module = torch.nn.Linear(2, 2)  # initialize on cpu
 56 |     | -        device = init_from_env()  # should be cuda device
 57 |     | -        module = prepare_module(module, device, strategy=NOOPStrategy())
 58 |     | -        self.assertNotEqual(next(module.parameters()).device, device)
 59 |     | -
 60 |     | -        module2 = torch.nn.Linear(2, 2)  # initialize on cpu
 61 |     | -        module2 = prepare_module(module2, device, strategy="noop")
 62 |     | -        self.assertNotEqual(next(module2.parameters()).device, device)
 63 |     | -
 64 |     | -    @skip_if_not_gpu
 65 |     | -    @skip_if_not_distributed
 66 |     | -    def test_prepare_ddp(self) -> None:
 67 |     | -        spawn_multi_process(
 68 |     | -            2,
 69 |     | -            "nccl",
 70 |     | -            self._test_prepare_ddp,
 71 |     | -        )
 72 |     | -
 73 |     | -    @staticmethod
 74 |     | -    def _test_prepare_ddp() -> None:
 75 |     | -        module = torch.nn.Linear(2, 2)
 76 |     | -        device = init_from_env()
 77 |     | -        ddp_module = prepare_ddp(
 78 |     | -            module,
 79 |     | -            device,
 80 |     | -            DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True),
 81 |     | -        )
 82 |     | -        tc = unittest.TestCase()
 83 |     | -        tc.assertTrue(isinstance(ddp_module, DDP))
 84 |     | -
 85 |     | -    @skip_if_not_gpu
 86 |     | -    @skip_if_not_distributed
 87 |     | -    def test_prepare_fsdp(self) -> None:
 88 |     | -        spawn_multi_process(
 89 |     | -            2,
 90 |     | -            "nccl",
 91 |     | -            self._test_prepare_fsdp,
 92 |     | -        )
 93 |     | -
 94 |     | -    @staticmethod
 95 |     | -    def _test_prepare_fsdp() -> None:
 96 |     | -        module = torch.nn.Linear(2, 2)
 97 |     | -        device = init_from_env()
 98 |     | -        fsdp_module = prepare_fsdp(module, device, FSDPStrategy(limit_all_gathers=True))
 99 |     | -        tc = unittest.TestCase()
100 |     | -        tc.assertTrue(isinstance(fsdp_module, FSDP))
101 |     | -
102 |     | -    @skip_if_not_distributed
103 |     | -    @skip_if_not_gpu
104 |     | -    def test_fsdp_pytorch_version(self) -> None:
105 |     | -        """
106 |     | -        Test that a RuntimeError is thrown when using FSDP, and PyTorch < v1.12
107 |     | -        """
108 |     | -        spawn_multi_process(
109 |     | -            2,
110 |     | -            "nccl",
111 |     | -            self._test_fsdp_pytorch_version,
112 |     | -        )
113 |     | -
114 |     | -    @staticmethod
115 |     | -    def _test_fsdp_pytorch_version() -> None:
116 |     | -        device = init_from_env()
117 |     | -        module = torch.nn.Linear(2, 2).to(device)
118 |     | -
119 |     | -        tc = unittest.TestCase()
120 |     | -        with patch(
121 |     | -            "torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
122 |     | -            return_value=False,
123 |     | -        ), tc.assertRaisesRegex(
124 |     | -            RuntimeError,
125 |     | -            "Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
126 |     | -        ):
127 |     | -            _ = prepare_fsdp(module, device, FSDPStrategy())
128 |     | -
129 |     | -    @staticmethod
130 |     | -    def _test_is_fsdp_module() -> None:
131 |     | -        model = torch.nn.Linear(1, 1)
132 |     | -        assert not _is_fsdp_module(model)
133 |     | -        model = FSDP(torch.nn.Linear(1, 1))
134 |     | -        assert _is_fsdp_module(model)
135 |     | -        model = torch.nn.Linear(1, 1)
136 |     | -        if is_torch_version_geq_2_0():
137 |     | -            fully_shard(model)
138 |     | -            assert _is_fsdp_module(model)
139 |     | -
140 |     | -    @skip_if_not_distributed
141 |     | -    @unittest.skipUnless(
142 |     | -        condition=bool(torch.cuda.device_count() >= 2),
143 |     | -        reason="This test needs 2 GPUs to run.",
144 |     | -    )
145 |     | -    def test_is_fsdp_module(self) -> None:
146 |     | -        spawn_multi_process(
147 |     | -            2,
148 |     | -            "gloo",
149 |     | -            self._test_is_fsdp_module,
150 |     | -        )
151 |     | -
152 |     | -    @skip_if_not_distributed
153 |     | -    @skip_if_not_gpu
154 |     | -    def test_fdsp_precision(self) -> None:
155 |     | -        spawn_multi_process(
156 |     | -            2,
157 |     | -            "nccl",
158 |     | -            self._test_fdsp_precision,
159 |     | -        )
160 |     | -
161 |     | -    @staticmethod
162 |     | -    def _test_fdsp_precision() -> None:
163 |     | -        module = torch.nn.Linear(1, 1)
164 |     | -        device = init_from_env()
165 |     | -        mixed_precision = MixedPrecision(
166 |     | -            param_dtype=torch.float64,
167 |     | -        )
168 |     | -        fsdp_module = prepare_fsdp(
169 |     | -            module, device, FSDPStrategy(mixed_precision=mixed_precision)
170 |     | -        )
171 |     | -        tc = unittest.TestCase()
172 |     | -        tc.assertTrue(isinstance(fsdp_module, FSDP))
173 |     | -        tc.assertEqual(
174 |     | -            fsdp_module.mixed_precision.param_dtype, mixed_precision.param_dtype
175 |     | -        )
176 |     | -
177 |     | -    @skip_if_not_distributed
178 |     | -    @skip_if_not_gpu
179 |     | -    @unittest.skipUnless(
180 |     | -        condition=bool(torch.cuda.device_count() >= 2),
181 |     | -        reason="This test needs 2 GPUs to run.",
182 |     | -    )
183 |     | -    def test_fdsp_str_types(self) -> None:
184 |     | -        spawn_multi_process(
185 |     | -            2,
186 |     | -            "nccl",
187 |     | -            self._test_fdsp_precision_str_types,
188 |     | -        )
189 |     | -        spawn_multi_process(
190 |     | -            2,
191 |     | -            "nccl",
192 |     | -            self._test_fdsp_backward_prefetch_str_types,
193 |     | -        )
194 |     | -        spawn_multi_process(
195 |     | -            2,
196 |     | -            "nccl",
197 |     | -            self._test_fdsp_sharding_strategy_str_types,
198 |     | -        )
199 |     | -        spawn_multi_process(
200 |     | -            2,
201 |     | -            "nccl",
202 |     | -            self._test_fdsp_state_dict_str_types,
203 |     | -        )
204 |     | -
205 |     | -    @staticmethod
206 |     | -    def _test_fdsp_precision_str_types() -> None:
207 |     | -        from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision
208 |     | -
209 |     | -        module = torch.nn.Linear(1, 1)
210 |     | -        device = init_from_env()
211 |     | -        mixed_precision = _MixedPrecision(
212 |     | -            param_dtype="fp16",
213 |     | -            reduce_dtype="bf16",
214 |     | -            buffer_dtype="fp32",
215 |     | -        )
216 |     | -
217 |     | -        fsdp_module = prepare_fsdp(
218 |     | -            module, device, FSDPStrategy(mixed_precision=mixed_precision)
219 |     | -        )
220 |     | -        tc = unittest.TestCase()
221 |     | -        tc.assertTrue(isinstance(fsdp_module, FSDP))
222 |     | -
223 |     | -    @staticmethod
224 |     | -    def _test_fdsp_backward_prefetch_str_types() -> None:
225 |     | -        module = torch.nn.Linear(1, 1)
226 |     | -        device = init_from_env()
227 |     | -
228 |     | -        tc = unittest.TestCase()
229 |     | -        for value in ["BACKWARD_PRE", "BACKWARD_POST"]:
230 |     | -            fsdp_module = prepare_fsdp(
231 |     | -                module, device, FSDPStrategy(backward_prefetch=value)
232 |     | -            )
233 |     | -            tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")
234 |     | -
235 |     | -    @staticmethod
236 |     | -    def _test_fdsp_sharding_strategy_str_types() -> None:
237 |     | -        module = torch.nn.Linear(1, 1)
238 |     | -        device = init_from_env()
239 |     | -
240 |     | -        tc = unittest.TestCase()
241 |     | -        for value in [
242 |     | -            "FULL_SHARD",
243 |     | -            "SHARD_GRAD_OP",
244 |     | -            "NO_SHARD",
245 |     | -            # skip hybrid strategy; tricky to configure in-test
246 |     | -        ]:
247 |     | -
248 |     | -            fsdp_module = prepare_fsdp(
249 |     | -                module,
250 |     | -                device,
251 |     | -                FSDPStrategy(sharding_strategy=value),
252 |     | -            )
253 |     | -            tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")
254 |     | -
255 |     | -    @staticmethod
256 |     | -    def _test_fdsp_state_dict_str_types() -> None:
257 |     | -        module = torch.nn.Linear(1, 1)
258 |     | -        device = init_from_env()
259 |     | -
260 |     | -        tc = unittest.TestCase()
261 |     | -        for value in [
262 |     | -            "FULL_STATE_DICT",
263 |     | -            "LOCAL_STATE_DICT",
264 |     | -            "SHARDED_STATE_DICT",
265 |     | -        ]:
266 |     | -            fsdp_module = prepare_fsdp(
267 |     | -                module, device, FSDPStrategy(state_dict_type=value)
268 |     | -            )
269 |     | -            tc.assertTrue(isinstance(fsdp_module, FSDP), f"tested value: {value}")
270 |     | -
271 |  34 |     def test_invalid_fsdp_strategy_str_values(self) -> None:
272 |  35 |         from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision
273 |  36 |
@@ -315,52 +78,16 @@ def test_prepare_module_strategy_invalid_str(self) -> None:
315 |  78 |                 strategy="foo",
316 |  79 |             )
317 |  80 |
318 |     | -    @skip_if_not_distributed
319 |     | -    @skip_if_not_gpu
320 |     | -    def test_prepare_module_with_fsdp(self) -> None:
321 |     | -        """
322 |     | -        Launch tests of FSDP strategy
323 |     | -        """
324 |     | -        spawn_multi_process(
325 |     | -            2,
326 |     | -            "nccl",
327 |     | -            self._test_prepare_module_fsdp_strategy_wrapped_in_fsdp,
328 |     | -        )
329 |     | -        spawn_multi_process(
330 |     | -            2,
331 |     | -            "nccl",
332 |     | -            self._test_prepare_module_fsdp_string_wrapped_in_fsdp,
333 |     | -        )
334 |     | -
335 |     | -    @staticmethod
336 |     | -    def _test_prepare_module_fsdp_strategy_wrapped_in_fsdp() -> None:
337 |     | -        """
338 |     | -        Test that the module is correctly wrapped in FSDP
339 |     | -        """
340 |     | -
341 |     | -        fsdp_module = prepare_module(
342 |     | -            module=torch.nn.Linear(2, 2),
343 |     | -            device=init_from_env(),
344 |     | -            strategy=FSDPStrategy(),
345 |     | -        )
346 |     | -        tc = unittest.TestCase()
347 |     | -
348 |     | -        tc.assertTrue(isinstance(fsdp_module, FSDP))
349 |     | -
350 |     | -    @staticmethod
351 |     | -    def _test_prepare_module_fsdp_string_wrapped_in_fsdp() -> None:
352 |     | -        """
353 |     | -        Test that the module is correctly wrapped in FSDP when passing "fsdp" as a string
354 |     | -        """
    |  81 | +    def test_prepare_noop(self) -> None:
    |  82 | +        device = torch.device("cuda")  # Suppose init_from_env returns cuda
355 |  83 |
356 |     | -        fsdp_module = prepare_module(
357 |     | -            module=torch.nn.Linear(2, 2),
358 |     | -            device=init_from_env(),
359 |     | -            strategy="fsdp",
360 |     | -        )
361 |     | -        tc = unittest.TestCase()
    |  84 | +        module = torch.nn.Linear(2, 2)  # initialize on cpu
    |  85 | +        module = prepare_module(module, device, strategy=NOOPStrategy())
    |  86 | +        self.assertNotEqual(next(module.parameters()).device, device)
362 |  87 |
363 |     | -        tc.assertTrue(isinstance(fsdp_module, FSDP))
    |  88 | +        module2 = torch.nn.Linear(2, 2)  # initialize on cpu
    |  89 | +        module2 = prepare_module(module2, device, strategy="noop")
    |  90 | +        self.assertNotEqual(next(module2.parameters()).device, device)
364 |  91 |
365 |  92 |     @skip_if_not_distributed
366 |  93 |     def test_prepare_module_with_ddp(self) -> None:
@@ -443,29 +170,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> None:
443 | 170 |                 torch_compile_params=TorchCompileParams(backend="inductor"),
444 | 171 |             )
445 | 172 |
446 |     | -    @unittest.skipUnless(
447 |     | -        condition=COMPILE_AVAIL,
448 |     | -        reason="This test needs PyTorch 1.13 or greater to run.",
449 |     | -    )
450 |     | -    @skip_if_not_gpu
451 |     | -    def test_prepare_module_compile_module_state_dict(self) -> None:
452 |     | -        device = init_from_env()
453 |     | -        my_module = torch.nn.Linear(2, 2, device=device)
454 |     | -        my_module_state_dict = my_module.state_dict()
455 |     | -        self.assertIsNone(my_module._compiled_call_impl)
456 |     | -        compiled_module = prepare_module(
457 |     | -            module=my_module,
458 |     | -            device=device,
459 |     | -            torch_compile_params=TorchCompileParams(backend="inductor"),
460 |     | -        )
461 |     | -        compiled_state_dict = compiled_module.state_dict()
462 |     | -        self.assertCountEqual(compiled_state_dict.keys(), my_module_state_dict.keys())
463 |     | -        for k in compiled_state_dict.keys():
464 |     | -            self.assertTrue(
465 |     | -                torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
466 |     | -            )
467 |     | -        self.assertIsNotNone(compiled_module._compiled_call_impl)
468 |     | -
469 | 173 |     @unittest.skipUnless(
470 | 174 |         condition=COMPILE_AVAIL,
471 | 175 |         reason="This test needs PyTorch 1.13 or greater to run.",
@@ -494,3 +198,25 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
494 | 198 |                 strategy=FSDPStrategy(use_orig_params=False),
495 | 199 |                 torch_compile_params=TorchCompileParams(),
496 | 200 |             )
    | 201 | +
    | 202 | +    @unittest.skipUnless(
    | 203 | +        condition=COMPILE_AVAIL,
    | 204 | +        reason="This test needs PyTorch 1.13 or greater to run.",
    | 205 | +    )
    | 206 | +    def test_prepare_module_compile_module_state_dict(self) -> None:
    | 207 | +        device = init_from_env()
    | 208 | +        my_module = torch.nn.Linear(2, 2, device=device)
    | 209 | +        my_module_state_dict = my_module.state_dict()
    | 210 | +        self.assertIsNone(my_module._compiled_call_impl)
    | 211 | +        compiled_module = prepare_module(
    | 212 | +            module=my_module,
    | 213 | +            device=device,
    | 214 | +            torch_compile_params=TorchCompileParams(backend="inductor"),
    | 215 | +        )
    | 216 | +        compiled_state_dict = compiled_module.state_dict()
    | 217 | +        self.assertCountEqual(compiled_state_dict.keys(), my_module_state_dict.keys())
    | 218 | +        for k in compiled_state_dict.keys():
    | 219 | +            self.assertTrue(
    | 220 | +                torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
    | 221 | +            )
    | 222 | +        self.assertIsNotNone(compiled_module._compiled_call_impl)