 12 |  12 | from unittest.mock import MagicMock, patch
 13 |  13 |
 14 |  14 | import torch
 15 |     | -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 16 |  15 | from torchtnt.framework.auto_unit import TrainStepResults
 17 |     | -from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
    |  16 | +from torchtnt.utils.test_utils import skip_if_not_distributed
 18 |  17 |
 19 |  18 | from torchtnt.utils.version import is_torch_version_geq_1_13
 20 |  19 |
 23 |  22 | COMPILE_AVAIL = True
 24 |  23 | import torch._dynamo
 25 |  24 |
 26 |     | -from copy import deepcopy
 27 |     | -
 28 |  25 | from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
 29 |  26 |
 30 |  27 | from torch.distributed import GradBucket
 31 |     | -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 32 |  28 | from torchtnt.framework._test_utils import (
 33 |  29 |     DummyAutoUnit,
 34 |  30 |     generate_random_dataloader,
 49 |  45 | from torchtnt.framework.unit import TPredictData
 50 |  46 | from torchtnt.utils.device import copy_data_to_device
 51 |  47 | from torchtnt.utils.distributed import spawn_multi_process
 52 |     | -from torchtnt.utils.env import init_from_env, seed
    |  48 | +from torchtnt.utils.env import init_from_env
 53 |  49 | from torchtnt.utils.lr_scheduler import TLRScheduler
 54 |     | -from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
    |  50 | +from torchtnt.utils.prepare_module import DDPStrategy
 55 |  51 | from torchtnt.utils.progress import Progress
 56 |  52 | from torchtnt.utils.timer import Timer
 57 |  53 |
@@ -81,31 +77,6 @@ def test_app_state_mixin(self) -> None:
 81 |  77 |         for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
 82 |  78 |             self.assertIn(key, auto_unit.app_state())
 83 |  79 |
 84 |     | -    @skip_if_not_gpu
 85 |     | -    @skip_if_not_distributed
 86 |     | -    def test_fsdp_fp16(self) -> None:
 87 |     | -        """
 88 |     | -        Test that FSDP + FP16 uses ShardedGradScaler
 89 |     | -        """
 90 |     | -        spawn_multi_process(
 91 |     | -            2,
 92 |     | -            "nccl",
 93 |     | -            self._test_fsdp_fp16,
 94 |     | -        )
 95 |     | -
 96 |     | -    @staticmethod
 97 |     | -    def _test_fsdp_fp16() -> None:
 98 |     | -        device = init_from_env()
 99 |     | -        my_module = torch.nn.Linear(2, 2)
100 |     | -        auto_unit_fsdp = DummyAutoUnit(
101 |     | -            module=my_module,
102 |     | -            device=device,
103 |     | -            strategy=FSDPStrategy(),
104 |     | -            precision="fp16",
105 |     | -        )
106 |     | -        tc = unittest.TestCase()
107 |     | -        tc.assertTrue(isinstance(auto_unit_fsdp.grad_scaler, ShardedGradScaler))
108 |     | -
109 |  80 |     def test_lr_scheduler_step(self) -> None:
110 |  81 |         """
111 |  82 |         Test that the lr scheduler is stepped every optimizer step when step_lr_interval="step"
@@ -150,49 +121,6 @@ def test_lr_scheduler_epoch(self) -> None:
150 | 121 |         train(auto_unit, train_dataloader=train_dl, max_epochs=max_epochs)
151 | 122 |         self.assertEqual(auto_unit.lr_scheduler.step.call_count, max_epochs)
152 | 123 |
153 |     | -    @skip_if_not_gpu
154 |     | -    @patch("torch.autocast")
155 |     | -    def test_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
156 |     | -        """
157 |     | -        Test that the mixed precision autocast context is called when fp16 precision is set
158 |     | -        """
159 |     | -        my_module = torch.nn.Linear(2, 2)
160 |     | -        auto_unit = DummyAutoUnit(
161 |     | -            module=my_module,
162 |     | -            precision="fp16",
163 |     | -        )
164 |     | -        dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
165 |     | -        state = get_dummy_train_state(dummy_iterable)
166 |     | -        auto_unit.train_step(
167 |     | -            state=state,
168 |     | -            data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
169 |     | -        )
170 |     | -        mock_autocast.assert_called_with(
171 |     | -            device_type="cuda", dtype=torch.float16, enabled=True
172 |     | -        )
173 |     | -
174 |     | -    @skip_if_not_gpu
175 |     | -    @patch("torch.autocast")
176 |     | -    def test_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
177 |     | -        """
178 |     | -        Test that the mixed precision autocast context is called when bf16 precision is set
179 |     | -        """
180 |     | -        my_module = torch.nn.Linear(2, 2)
181 |     | -
182 |     | -        auto_unit = DummyAutoUnit(
183 |     | -            module=my_module,
184 |     | -            precision="bf16",
185 |     | -        )
186 |     | -        dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
187 |     | -        state = get_dummy_train_state(dummy_iterable)
188 |     | -        auto_unit.train_step(
189 |     | -            state=state,
190 |     | -            data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
191 |     | -        )
192 |     | -        mock_autocast.assert_called_with(
193 |     | -            device_type="cuda", dtype=torch.bfloat16, enabled=True
194 |     | -        )
195 |     | -
196 | 124 |     def test_mixed_precision_invalid_str(self) -> None:
197 | 125 |         """
198 | 126 |         Test that an exception is raised with an invalid precision string
@@ -310,191 +238,6 @@ def test_stochastic_weight_averaging_update_freq(self) -> None:
310 | 238 |         # 1 warmup + epoch 2 + epoch 3 = 2
311 | 239 |         self.assertEqual(update_swa_mock.call_count, 2)
312 | 240 |
313 |     | -    @skip_if_not_distributed
314 |     | -    @skip_if_not_gpu
315 |     | -    def test_stochastic_weight_averaging_fsdp(self) -> None:
316 |     | -        """
317 |     | -        Test that swa params with FSDP is identical to non-FSDP swa
318 |     | -        """
319 |     | -        spawn_multi_process(
320 |     | -            2,
321 |     | -            "nccl",
322 |     | -            self._test_stochastic_weight_averaging_fsdp,
323 |     | -        )
324 |     | -
325 |     | -    @staticmethod
326 |     | -    def _test_stochastic_weight_averaging_fsdp() -> None:
327 |     | -        class Net(torch.nn.Module):
328 |     | -            def __init__(self):
329 |     | -                super(Net, self).__init__()
330 |     | -                self.l1 = torch.nn.Linear(2, 2)
331 |     | -                self.b1 = torch.nn.BatchNorm1d(2)
332 |     | -                self.l2 = torch.nn.Linear(2, 2)
333 |     | -
334 |     | -            def forward(self, x):
335 |     | -                x = self.l1(x)
336 |     | -                x = self.b1(x)
337 |     | -                x = self.l2(x)
338 |     | -                return x
339 |     | -
340 |     | -        # so all ranks start with same initialized weights
341 |     | -        seed(0)
342 |     | -        device = init_from_env()
343 |     | -        my_module = Net()
344 |     | -
345 |     | -        auto_unit = DummyAutoUnit(
346 |     | -            module=deepcopy(my_module),
347 |     | -            device=device,
348 |     | -            step_lr_interval="step",
349 |     | -            swa_params=SWAParams(
350 |     | -                warmup_steps_or_epochs=1,
351 |     | -                step_or_epoch_update_freq=1,
352 |     | -                swalr_params=SWALRParams(
353 |     | -                    anneal_steps_or_epochs=3,
354 |     | -                ),
355 |     | -                averaging_method="ema",
356 |     | -            ),
357 |     | -        )
358 |     | -
359 |     | -        auto_unit_fsdp = DummyAutoUnit(
360 |     | -            module=my_module,
361 |     | -            device=device,
362 |     | -            step_lr_interval="step",
363 |     | -            strategy=FSDPStrategy(),
364 |     | -            swa_params=SWAParams(
365 |     | -                warmup_steps_or_epochs=1,
366 |     | -                step_or_epoch_update_freq=1,
367 |     | -                swalr_params=SWALRParams(
368 |     | -                    anneal_steps_or_epochs=3,
369 |     | -                ),
370 |     | -                averaging_method="ema",
371 |     | -            ),
372 |     | -        )
373 |     | -
374 |     | -        input_dim = 2
375 |     | -        dataset_len = 10
376 |     | -        batch_size = 2
377 |     | -
378 |     | -        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
379 |     | -        train(auto_unit, dataloader, max_epochs=1, max_steps_per_epoch=5)
380 |     | -        train(auto_unit_fsdp, dataloader, max_epochs=1, max_steps_per_epoch=5)
381 |     | -
382 |     | -        swa_params = list(auto_unit.swa_model.module.parameters())
383 |     | -        with FSDP.summon_full_params(auto_unit_fsdp.swa_model):
384 |     | -            swa_fsdp_params = list(auto_unit_fsdp.swa_model.module.parameters())
385 |     | -
386 |     | -        # Iterate and compare each parameter
387 |     | -        for p1, p2 in zip(swa_params, swa_fsdp_params, strict=True):
388 |     | -            torch.testing.assert_close(p2, p1, check_device=False)
389 |     | -
390 |     | -    @skip_if_not_gpu
391 |     | -    @patch("torch.autocast")
392 |     | -    def test_eval_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
393 |     | -        """
394 |     | -        Test that the mixed precision autocast context is called during evaluate when precision = bf16
395 |     | -        """
396 |     | -        my_module = torch.nn.Linear(2, 2)
397 |     | -        auto_unit = DummyAutoUnit(
398 |     | -            module=my_module,
399 |     | -            precision="bf16",
400 |     | -        )
401 |     | -
402 |     | -        input_dim = 2
403 |     | -        dataset_len = 8
404 |     | -        batch_size = 2
405 |     | -
406 |     | -        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
407 |     | -        evaluate(auto_unit, eval_dl)
408 |     | -        mock_autocast.assert_called_with(
409 |     | -            device_type="cuda", dtype=torch.bfloat16, enabled=True
410 |     | -        )
411 |     | -
412 |     | -    @skip_if_not_gpu
413 |     | -    @skip_if_not_distributed
414 |     | -    def test_no_sync(self) -> None:
415 |     | -        """
416 |     | -        Test that the no_sync autocast context is correctly applied when using gradient accumulation
417 |     | -        """
418 |     | -        spawn_multi_process(
419 |     | -            2,
420 |     | -            "nccl",
421 |     | -            self._test_ddp_no_sync,
422 |     | -        )
423 |     | -        spawn_multi_process(
424 |     | -            2,
425 |     | -            "nccl",
426 |     | -            self._test_fsdp_no_sync,
427 |     | -        )
428 |     | -
429 |     | -    @staticmethod
430 |     | -    def _test_ddp_no_sync() -> None:
431 |     | -        """
432 |     | -        Test that the no_sync autocast context is correctly applied when using gradient accumulation and DDP
433 |     | -        """
434 |     | -
435 |     | -        my_module = torch.nn.Linear(2, 2)
436 |     | -
437 |     | -        auto_unit = DummyAutoUnit(
438 |     | -            module=my_module,
439 |     | -            strategy=DDPStrategy(),
440 |     | -            gradient_accumulation_steps=2,
441 |     | -        )
442 |     | -
443 |     | -        dummy_iterator = iter(
444 |     | -            [(torch.ones(2, 2), torch.ones(2, 2)), (torch.ones(2, 2), torch.ones(2, 2))]
445 |     | -        )
446 |     | -        state = get_dummy_train_state()
447 |     | -
448 |     | -        # for the first step no_sync should be called since we accumulate gradients
449 |     | -        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
450 |     | -            auto_unit.train_step(
451 |     | -                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
452 |     | -            )
453 |     | -            no_sync_mock.assert_called_once()
454 |     | -
455 |     | -        auto_unit.train_progress.increment_step()
456 |     | -        # for the second step no_sync should not be called since we run optimizer step
457 |     | -        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
458 |     | -            auto_unit.train_step(
459 |     | -                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
460 |     | -            )
461 |     | -            no_sync_mock.assert_not_called()
462 |     | -
463 |     | -    @staticmethod
464 |     | -    def _test_fsdp_no_sync() -> None:
465 |     | -        """
466 |     | -        Test that the no_sync autocast context is correctly applied when using gradient accumulation and FSDP
467 |     | -        """
468 |     | -        device = init_from_env()
469 |     | -        my_module = torch.nn.Linear(2, 2).to(device)
470 |     | -
471 |     | -        auto_unit = DummyAutoUnit(
472 |     | -            module=my_module,
473 |     | -            device=device,
474 |     | -            strategy=FSDPStrategy(),
475 |     | -            gradient_accumulation_steps=2,
476 |     | -        )
477 |     | -
478 |     | -        dummy_iterator = iter(
479 |     | -            [(torch.ones(2, 2), torch.ones(2, 2)), (torch.ones(2, 2), torch.ones(2, 2))]
480 |     | -        )
481 |     | -        state = get_dummy_train_state()
482 |     | -
483 |     | -        # for the first step no_sync should be called since we accumulate gradients
484 |     | -        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
485 |     | -            auto_unit.train_step(
486 |     | -                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
487 |     | -            )
488 |     | -            no_sync_mock.assert_called_once()
489 |     | -
490 |     | -        auto_unit.train_progress.increment_step()
491 |     | -        # for the second step no_sync should not be called since we run optimizer step
492 |     | -        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
493 |     | -            auto_unit.train_step(
494 |     | -                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
495 |     | -            )
496 |     | -            no_sync_mock.assert_not_called()
497 |     | -
498 | 241 |     def test_move_data_to_device(self) -> None:
499 | 242 |         """
500 | 243 |         Test that move_data_to_device is called
@@ -746,53 +489,6 @@ def test_auto_unit_timing_predict(self) -> None:
746 | 489 |             timer=Timer(),
747 | 490 |         )
748 | 491 |
749 |     | -    @skip_if_not_gpu
750 |     | -    @patch("torch.autocast")
751 |     | -    def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
752 |     | -        """
753 |     | -        Test that the mixed precision autocast context is called during predict when precision = fp16
754 |     | -        """
755 |     | -        my_module = torch.nn.Linear(2, 2)
756 |     | -        auto_unit = AutoPredictUnit(module=my_module, precision="fp16")
757 |     | -
758 |     | -        input_dim = 2
759 |     | -        dataset_len = 8
760 |     | -        batch_size = 2
761 |     | -
762 |     | -        predict_dl = generate_random_iterable_dataloader(
763 |     | -            dataset_len, input_dim, batch_size
764 |     | -        )
765 |     | -        predict(auto_unit, predict_dl)
766 |     | -        mock_autocast.assert_called_with(
767 |     | -            device_type="cuda", dtype=torch.float16, enabled=True
768 |     | -        )
769 |     | -
770 |     | -    @unittest.skipUnless(
771 |     | -        condition=COMPILE_AVAIL,
772 |     | -        reason="This test needs PyTorch 1.13 or greater to run.",
773 |     | -    )
774 |     | -    @skip_if_not_gpu
775 |     | -    @patch("torch.compile")
776 |     | -    def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
777 |     | -        """
778 |     | -        e2e torch compile on predict
779 |     | -        """
780 |     | -        my_module = torch.nn.Linear(2, 2)
781 |     | -        auto_unit = AutoPredictUnit(
782 |     | -            module=my_module,
783 |     | -            torch_compile_params=TorchCompileParams(backend="eager"),
784 |     | -        )
785 |     | -
786 |     | -        input_dim = 2
787 |     | -        dataset_len = 8
788 |     | -        batch_size = 2
789 |     | -
790 |     | -        predict_dl = generate_random_iterable_dataloader(
791 |     | -            dataset_len, input_dim, batch_size
792 |     | -        )
793 |     | -        predict(auto_unit, predict_dl)
794 |     | -        mock_dynamo.assert_called()
795 |     | -
796 | 492 |     def test_auto_predict_unit_timing_predict(self) -> None:
797 | 493 |         """
798 | 494 |         Test auto timing in AutoUnit for predict