
Commit bf7747e

rahulsingh-intel authored and guangyey committed
Test generalization for multiple accelerator devices (pytorch#139184)
Motivation: Generalize unit tests so that they can be executed on CUDA and non-CUDA devices. Dependency: pytorch#133209, now merged. An earlier pytorch#135242 existed for these changes and was closed due to incorrect commits; I have incorporated the changes suggested in its review comments. @kwen2501 @zeshengzong Please review the changes.
Pull Request resolved: pytorch#139184
Approved by: https://github.com/kwen2501
Co-authored-by: Yu, Guangye <[email protected]>
1 parent 2e1ea85 commit bf7747e

24 files changed (+414 / -379 lines)
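Before the per-file diffs, the pattern the commit applies everywhere is worth stating once: resolve the available accelerator a single time, then route tensor placement and backend utilities through that handle instead of hard-coding torch.cuda. The snippet below is a minimal sketch of that idea, not the test-suite helper itself; the real tests import get_devtype from torch.testing._internal.common_fsdp, and the CUDA-or-CPU fallback here is an assumption for illustration.

import torch

def get_devtype_sketch() -> str:
    # Stand-in for the internal get_devtype() helper: prefer CUDA when
    # present, otherwise fall back to CPU (the real helper also knows
    # about other accelerators such as hpu/xpu).
    return "cuda" if torch.cuda.is_available() else "cpu"

device_type = torch.device(get_devtype_sketch())

# Placement goes through the resolved device rather than .cuda(), and
# per-backend utilities are reached via torch.get_device_module(...),
# exactly as the hunks below do.
x = torch.randn(4, 4, device=device_type)
if device_type.type != "cpu":
    torch.get_device_module(device_type.type).reset_peak_memory_stats()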

test/distributed/fsdp/test_checkpoint_wrapper.py

Lines changed: 15 additions & 11 deletions
@@ -16,13 +16,16 @@
     OffloadWrapper,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.utils.checkpoint import checkpoint
 
 
 _SAVED_PREFIX = "_saved_"
 GRAD_FN_NEXT_FUNCTIONS = "next_functions"
 
+device_type = torch.device(get_devtype())
+
 
 class CheckpointWrapperTest(TestCase):
     def test_load_activation_checkpointed_module(self):
@@ -130,7 +133,7 @@ def get_ctx_mgrs():
         m(torch.randn(2, 1)).sum().backward()
         self.assertEqual(2, count)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+    @unittest.skip
     def test_checkpoint_wrapper_parity(self):
         """
         Tests that using checkpoint_wrapper or the functional
@@ -155,9 +158,11 @@ def __init__(
                self.use_reentrant = use_reentrant
                wrp = partial(
                    checkpoint_wrapper,
-                   checkpoint_impl=CheckpointImpl.REENTRANT
-                   if use_reentrant
-                   else CheckpointImpl.NO_REENTRANT,
+                   checkpoint_impl=(
+                       CheckpointImpl.REENTRANT
+                       if use_reentrant
+                       else CheckpointImpl.NO_REENTRANT
+                   ),
                )
                for _ in range(self.n):
                    l = nn.Sequential(
@@ -184,12 +189,12 @@ def test(use_checkpointing, use_wrapper, use_reentrant):
                use_checkpointing,
                use_wrapper=use_wrapper,
                use_reentrant=use_reentrant,
-            ).cuda()
-            x = torch.randn(10000, 256, requires_grad=True).cuda()
-            torch.cuda.reset_peak_memory_stats()
+            ).to(device_type.type)
+            x = torch.randn(10000, 256, requires_grad=True).to(device_type.type)
+            torch.get_device_module(device_type.type).reset_peak_memory_stats()
             loss = a(x).sum()
             loss.backward()
-            return torch.cuda.max_memory_allocated()
+            return torch.get_device_module(device_type.type).max_memory_allocated()
 
         functional_no_reentrant = test(
             use_checkpointing=True, use_wrapper=False, use_reentrant=False
@@ -333,13 +338,12 @@ def test_fqn(self):
         for fqn, _ in lin.named_parameters():
             self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.")
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
     def test_checkpoint_wrapper_cpu_offload(self):
         model = nn.Sequential(
             nn.Linear(10, 10),
             nn.Linear(10, 10),
             nn.Linear(10, 10),
-        ).cuda()
+        ).to(device_type.type)
 
         # Patch saved_tensor_hooks to make the unpack keep the tensor on CPU for
         # testing, otherwise the tensor access during the DFS will cause orig
@@ -358,7 +362,7 @@ def testing_cpu_offload_unpack_hook(packed):
 
         model = offload_wrapper(model)
 
-        inp = torch.randn(3, 10, device="cuda")
+        inp = torch.randn(3, 10, device=device_type.type)
         loss = model(inp).sum()
 
         # All autograd saved tensors should be offloaded to CPU.
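The parity test above now measures peak memory through the resolved backend module rather than torch.cuda. Below is a condensed, hedged sketch of that measurement, with the tensor shape taken from the hunk and the helper name being hypothetical.

import torch
import torch.nn as nn

def peak_memory_after_backward(model: nn.Module, device: torch.device) -> int:
    # Reset the backend's peak-memory counter, run forward/backward, and
    # return the high-water mark, mirroring the rewritten test() helper.
    torch.get_device_module(device.type).reset_peak_memory_stats()
    x = torch.randn(10000, 256, requires_grad=True).to(device.type)
    model(x).sum().backward()
    return torch.get_device_module(device.type).max_memory_allocated()

# Usage sketch, assuming an accelerator is available and device_type is the
# module-level handle defined from get_devtype():
# peak = peak_memory_after_backward(nn.Linear(256, 256).to(device_type.type), device_type)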

test/distributed/fsdp/test_distributed_checkpoint.py

Lines changed: 3 additions & 3 deletions
@@ -8,10 +8,10 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 from torch.distributed.fsdp.wrap import enable_wrap, wrap
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest, SkipModel
 from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
@@ -85,7 +85,7 @@ def test_distributed_checkpoint(self, state_dict_type) -> None:
         # TODO: add resharding test case.
 
 
-instantiate_parametrized_tests(TestDistributedCheckpoint)
-
+devices = ("cuda", "hpu")
+instantiate_device_type_tests(TestDistributedCheckpoint, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()
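The switch from instantiate_parametrized_tests to instantiate_device_type_tests is what fans one test class out per backend. Here is a self-contained sketch of that mechanism, using a hypothetical class and test name rather than the ones in this file.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase

class DeviceFanoutExample(TestCase):  # hypothetical class for illustration
    def test_tensor_lands_on_device(self, device):
        # Generated tests receive the device string ("cuda:0", "hpu:0", ...).
        t = torch.ones(2, device=device)
        self.assertEqual(t.device.type, torch.device(device).type)

# Creates per-backend classes such as DeviceFanoutExampleCUDA for the device
# types registered with the framework, filtered by only_for.
devices = ("cuda", "hpu")
instantiate_device_type_tests(DeviceFanoutExample, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()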

test/distributed/fsdp/test_fsdp_apply.py

Lines changed: 12 additions & 0 deletions
@@ -6,11 +6,13 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     DEVICEInitMode,
     FSDPInitMode,
     FSDPTest,
+    get_devtype,
     NestedWrappedModule,
     TransformerWithSharedParams,
 )
@@ -28,6 +30,8 @@
     )
     sys.exit(0)
 
+device_type = torch.device(get_devtype())
+
 
 class TestApply(FSDPTest):
     @property
@@ -67,37 +71,45 @@ def _check_apply(self, fsdp):
     def test_nested_module_apply(self):
         """Tests that ``apply()`` modifies parameter values in-place on a
         non-FSDP-root nested FSDP-wrapped model."""
+        fsdp_kwargs = {"device_id": device_type.type}
         nested_wrapped_module = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             DEVICEInitMode.DEVICE_AFTER,
+            fsdp_kwargs=fsdp_kwargs,
         )
         self._check_apply(nested_wrapped_module)
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_module_apply(self):
         """Tests that ``apply()`` modifies parameter values in-place on an
         FSDP-wrapped transformer model with shared parameters."""
+        fsdp_kwargs = {"device_id": device_type.type}
         transformer = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             DEVICEInitMode.DEVICE_AFTER,
+            fsdp_kwargs=fsdp_kwargs,
         )
         self._check_apply(transformer)
 
     @skip_if_lt_x_gpu(2)
     def test_apply_in_summon_raises_error(self):
         """Tests that calling ``apply()`` on an FSDP instance inside the
         ``summon_full_params()`` context raises an error."""
+        fsdp_kwargs = {"device_id": device_type.type}
         transformer = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             DEVICEInitMode.DEVICE_AFTER,
+            fsdp_kwargs=fsdp_kwargs,
         )
         with transformer.summon_full_params(transformer):
             with self.assertRaisesRegex(ValueError, "expected to be in states"):
                 transformer.apply(self._init_linear_weights)
 
 
+devices = ("cuda", "hpu")
+instantiate_device_type_tests(TestApply, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()
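The recurring change in this file is threading the resolved device into FSDP construction through fsdp_kwargs instead of leaving the CUDA default implicit. Below is a hedged sketch of what that amounts to at the FSDP call site; it assumes the test helpers forward fsdp_kwargs to the FSDP constructor and that a process group is already initialized, as it is inside FSDPTest.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# device_id tells FSDP which accelerator to place and shard the module on,
# replacing the earlier implicit torch.cuda.current_device() assumption.
# device_type is the module-level handle defined from get_devtype() above.
fsdp_kwargs = {"device_id": device_type.type}
wrapped = FSDP(nn.Linear(8, 8), **fsdp_kwargs)  # requires an initialized default process group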

test/distributed/fsdp/test_fsdp_backward_prefetch.py

Lines changed: 7 additions & 6 deletions
@@ -16,10 +16,12 @@
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_fsdp import FSDPTest
+from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 
 
+device_type = torch.device(get_devtype())
+
 NUM_ITERS = 2
 DECODER_PARAM_FQNS = [
     "decoder.layers.{index}.self_attn.in_proj_weight",
@@ -81,14 +83,13 @@ def world_size(self):
     def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
         rank = self.rank
         orig_get_handle_to_prefetch = _get_handle_to_prefetch
-
         torch.manual_seed(0)
         policy = ModuleWrapPolicy(
             {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
         )
         model = FSDP(
-            nn.Transformer(d_model=1024, nhead=8, device="cuda"),
-            device_id=torch.cuda.current_device(),
+            nn.Transformer(d_model=1024, nhead=8, device=device_type),
+            device_id=device_type.type,
             auto_wrap_policy=policy,
             use_orig_params=True,
             backward_prefetch=backward_prefetch,
@@ -97,8 +98,8 @@ def _dist_train(self, backward_prefetch=BackwardPrefetch.BACKWARD_PRE):
 
         # prepare input
         torch.manual_seed(rank + 1)
-        src = torch.randn((10, 1, 1024), device="cuda")
-        tgt = torch.randn((20, 1, 1024), device="cuda")
+        src = torch.randn((10, 1, 1024), device=device_type)
+        tgt = torch.randn((20, 1, 1024), device=device_type)
 
         # monkey patch
         all_handle_fqns: List[List[str]] = []
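The same substitution closes out the backward-prefetch test: the transformer, its FSDP wrapper, and the inputs are all created against the resolved device. A condensed sketch of that setup follows, with the dimensions copied from the hunks; it assumes the process group has already been initialized by the test harness.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

policy = ModuleWrapPolicy({nn.TransformerEncoderLayer, nn.TransformerDecoderLayer})
model = FSDP(
    nn.Transformer(d_model=1024, nhead=8, device=device_type),  # built directly on the accelerator
    device_id=device_type.type,
    auto_wrap_policy=policy,
    use_orig_params=True,
)
src = torch.randn((10, 1, 1024), device=device_type)
tgt = torch.randn((20, 1, 1024), device=device_type)
model(src, tgt).sum().backward()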
