
Commit e806af0

diego-urgell authored and facebook-github-bot committed
Move auto_unit GPU tests to dedicated file (#755)
Summary: Pull Request resolved: #755

Reviewed By: galrotem

Differential Revision: D55224519

fbshipit-source-id: ec179af6db2303ae4e32fb3d1568f7754aa54d90
1 parent bc2bf15 commit e806af0

File tree

2 files changed: +350 -307 lines changed

tests/framework/test_auto_unit.py

Lines changed: 3 additions & 307 deletions
@@ -12,9 +12,8 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torchtnt.framework.auto_unit import TrainStepResults
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
+from torchtnt.utils.test_utils import skip_if_not_distributed
 
 from torchtnt.utils.version import is_torch_version_geq_1_13
 
@@ -23,12 +22,9 @@
     COMPILE_AVAIL = True
     import torch._dynamo
 
-from copy import deepcopy
-
 from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
 
 from torch.distributed import GradBucket
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     generate_random_dataloader,
@@ -49,9 +45,9 @@
 from torchtnt.framework.unit import TPredictData
 from torchtnt.utils.device import copy_data_to_device
 from torchtnt.utils.distributed import spawn_multi_process
-from torchtnt.utils.env import init_from_env, seed
+from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
-from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
+from torchtnt.utils.prepare_module import DDPStrategy
 from torchtnt.utils.progress import Progress
 from torchtnt.utils.timer import Timer
 
@@ -81,31 +77,6 @@ def test_app_state_mixin(self) -> None:
         for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
             self.assertIn(key, auto_unit.app_state())
 
-    @skip_if_not_gpu
-    @skip_if_not_distributed
-    def test_fsdp_fp16(self) -> None:
-        """
-        Test that FSDP + FP16 uses ShardedGradScaler
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fsdp_fp16,
-        )
-
-    @staticmethod
-    def _test_fsdp_fp16() -> None:
-        device = init_from_env()
-        my_module = torch.nn.Linear(2, 2)
-        auto_unit_fsdp = DummyAutoUnit(
-            module=my_module,
-            device=device,
-            strategy=FSDPStrategy(),
-            precision="fp16",
-        )
-        tc = unittest.TestCase()
-        tc.assertTrue(isinstance(auto_unit_fsdp.grad_scaler, ShardedGradScaler))
-
     def test_lr_scheduler_step(self) -> None:
         """
         Test that the lr scheduler is stepped every optimizer step when step_lr_interval="step"
@@ -150,49 +121,6 @@ def test_lr_scheduler_epoch(self) -> None:
         train(auto_unit, train_dataloader=train_dl, max_epochs=max_epochs)
         self.assertEqual(auto_unit.lr_scheduler.step.call_count, max_epochs)
 
-    @skip_if_not_gpu
-    @patch("torch.autocast")
-    def test_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
-        """
-        Test that the mixed precision autocast context is called when fp16 precision is set
-        """
-        my_module = torch.nn.Linear(2, 2)
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            precision="fp16",
-        )
-        dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
-        state = get_dummy_train_state(dummy_iterable)
-        auto_unit.train_step(
-            state=state,
-            data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
-        )
-        mock_autocast.assert_called_with(
-            device_type="cuda", dtype=torch.float16, enabled=True
-        )
-
-    @skip_if_not_gpu
-    @patch("torch.autocast")
-    def test_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
-        """
-        Test that the mixed precision autocast context is called when bf16 precision is set
-        """
-        my_module = torch.nn.Linear(2, 2)
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            precision="bf16",
-        )
-        dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
-        state = get_dummy_train_state(dummy_iterable)
-        auto_unit.train_step(
-            state=state,
-            data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
-        )
-        mock_autocast.assert_called_with(
-            device_type="cuda", dtype=torch.bfloat16, enabled=True
-        )
-
     def test_mixed_precision_invalid_str(self) -> None:
         """
         Test that an exception is raised with an invalid precision string
@@ -310,191 +238,6 @@ def test_stochastic_weight_averaging_update_freq(self) -> None:
         # 1 warmup + epoch 2 + epoch 3 = 2
         self.assertEqual(update_swa_mock.call_count, 2)
 
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_stochastic_weight_averaging_fsdp(self) -> None:
-        """
-        Test that swa params with FSDP is identical to non-FSDP swa
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_stochastic_weight_averaging_fsdp,
-        )
-
-    @staticmethod
-    def _test_stochastic_weight_averaging_fsdp() -> None:
-        class Net(torch.nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.l1 = torch.nn.Linear(2, 2)
-                self.b1 = torch.nn.BatchNorm1d(2)
-                self.l2 = torch.nn.Linear(2, 2)
-
-            def forward(self, x):
-                x = self.l1(x)
-                x = self.b1(x)
-                x = self.l2(x)
-                return x
-
-        # so all ranks start with same initialized weights
-        seed(0)
-        device = init_from_env()
-        my_module = Net()
-
-        auto_unit = DummyAutoUnit(
-            module=deepcopy(my_module),
-            device=device,
-            step_lr_interval="step",
-            swa_params=SWAParams(
-                warmup_steps_or_epochs=1,
-                step_or_epoch_update_freq=1,
-                swalr_params=SWALRParams(
-                    anneal_steps_or_epochs=3,
-                ),
-                averaging_method="ema",
-            ),
-        )
-
-        auto_unit_fsdp = DummyAutoUnit(
-            module=my_module,
-            device=device,
-            step_lr_interval="step",
-            strategy=FSDPStrategy(),
-            swa_params=SWAParams(
-                warmup_steps_or_epochs=1,
-                step_or_epoch_update_freq=1,
-                swalr_params=SWALRParams(
-                    anneal_steps_or_epochs=3,
-                ),
-                averaging_method="ema",
-            ),
-        )
-
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        train(auto_unit, dataloader, max_epochs=1, max_steps_per_epoch=5)
-        train(auto_unit_fsdp, dataloader, max_epochs=1, max_steps_per_epoch=5)
-
-        swa_params = list(auto_unit.swa_model.module.parameters())
-        with FSDP.summon_full_params(auto_unit_fsdp.swa_model):
-            swa_fsdp_params = list(auto_unit_fsdp.swa_model.module.parameters())
-
-        # Iterate and compare each parameter
-        for p1, p2 in zip(swa_params, swa_fsdp_params, strict=True):
-            torch.testing.assert_close(p2, p1, check_device=False)
-
-    @skip_if_not_gpu
-    @patch("torch.autocast")
-    def test_eval_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
-        """
-        Test that the mixed precision autocast context is called during evaluate when precision = bf16
-        """
-        my_module = torch.nn.Linear(2, 2)
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            precision="bf16",
-        )
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        evaluate(auto_unit, eval_dl)
-        mock_autocast.assert_called_with(
-            device_type="cuda", dtype=torch.bfloat16, enabled=True
-        )
-
-    @skip_if_not_gpu
-    @skip_if_not_distributed
-    def test_no_sync(self) -> None:
-        """
-        Test that the no_sync autocast context is correctly applied when using gradient accumulation
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_ddp_no_sync,
-        )
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fsdp_no_sync,
-        )
-
-    @staticmethod
-    def _test_ddp_no_sync() -> None:
-        """
-        Test that the no_sync autocast context is correctly applied when using gradient accumulation and DDP
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            strategy=DDPStrategy(),
-            gradient_accumulation_steps=2,
-        )
-
-        dummy_iterator = iter(
-            [(torch.ones(2, 2), torch.ones(2, 2)), (torch.ones(2, 2), torch.ones(2, 2))]
-        )
-        state = get_dummy_train_state()
-
-        # for the first step no_sync should be called since we accumulate gradients
-        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
-            auto_unit.train_step(
-                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
-            )
-            no_sync_mock.assert_called_once()
-
-        auto_unit.train_progress.increment_step()
-        # for the second step no_sync should not be called since we run optimizer step
-        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
-            auto_unit.train_step(
-                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
-            )
-            no_sync_mock.assert_not_called()
-
-    @staticmethod
-    def _test_fsdp_no_sync() -> None:
-        """
-        Test that the no_sync autocast context is correctly applied when using gradient accumulation and FSDP
-        """
-        device = init_from_env()
-        my_module = torch.nn.Linear(2, 2).to(device)
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            device=device,
-            strategy=FSDPStrategy(),
-            gradient_accumulation_steps=2,
-        )
-
-        dummy_iterator = iter(
-            [(torch.ones(2, 2), torch.ones(2, 2)), (torch.ones(2, 2), torch.ones(2, 2))]
-        )
-        state = get_dummy_train_state()
-
-        # for the first step no_sync should be called since we accumulate gradients
-        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
-            auto_unit.train_step(
-                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
-            )
-            no_sync_mock.assert_called_once()
-
-        auto_unit.train_progress.increment_step()
-        # for the second step no_sync should not be called since we run optimizer step
-        with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
-            auto_unit.train_step(
-                state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
-            )
-            no_sync_mock.assert_not_called()
-
     def test_move_data_to_device(self) -> None:
         """
         Test that move_data_to_device is called
@@ -746,53 +489,6 @@ def test_auto_unit_timing_predict(self) -> None:
             timer=Timer(),
         )
 
-    @skip_if_not_gpu
-    @patch("torch.autocast")
-    def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
-        """
-        Test that the mixed precision autocast context is called during predict when precision = fp16
-        """
-        my_module = torch.nn.Linear(2, 2)
-        auto_unit = AutoPredictUnit(module=my_module, precision="fp16")
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        predict_dl = generate_random_iterable_dataloader(
-            dataset_len, input_dim, batch_size
-        )
-        predict(auto_unit, predict_dl)
-        mock_autocast.assert_called_with(
-            device_type="cuda", dtype=torch.float16, enabled=True
-        )
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @skip_if_not_gpu
-    @patch("torch.compile")
-    def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
-        """
-        e2e torch compile on predict
-        """
-        my_module = torch.nn.Linear(2, 2)
-        auto_unit = AutoPredictUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="eager"),
-        )
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        predict_dl = generate_random_iterable_dataloader(
-            dataset_len, input_dim, batch_size
-        )
-        predict(auto_unit, predict_dl)
-        mock_dynamo.assert_called()
-
     def test_auto_predict_unit_timing_predict(self) -> None:
         """
         Test auto timing in AutoUnit for predict

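The GPU-gated tests removed above are relocated to the commit's second file, which is not expanded in this view. A minimal sketch of how one of the relocated tests could be laid out there, reusing the decorators and helpers from the removed code; the file name (for example tests/framework/test_auto_unit_gpu.py) and the TestAutoUnitGPU class name are illustrative assumptions, not confirmed by this diff:

# Illustrative sketch only: the file and class names are assumptions, since the
# new file's contents are not shown in this diff view. The body mirrors the
# removed test_fsdp_fp16 test and the imports it needed.
import unittest

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torchtnt.framework._test_utils import DummyAutoUnit
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import FSDPStrategy
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class TestAutoUnitGPU(unittest.TestCase):
    @skip_if_not_gpu
    @skip_if_not_distributed
    def test_fsdp_fp16(self) -> None:
        """Test that FSDP + FP16 uses ShardedGradScaler."""
        spawn_multi_process(2, "nccl", self._test_fsdp_fp16)

    @staticmethod
    def _test_fsdp_fp16() -> None:
        # Runs on each spawned rank; checks the grad scaler type under FSDP + fp16.
        device = init_from_env()
        my_module = torch.nn.Linear(2, 2)
        auto_unit_fsdp = DummyAutoUnit(
            module=my_module,
            device=device,
            strategy=FSDPStrategy(),
            precision="fp16",
        )
        tc = unittest.TestCase()
        tc.assertTrue(isinstance(auto_unit_fsdp.grad_scaler, ShardedGradScaler))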