
Commit 1155c53

daisyden authored and guangyey committed
Port three dynamo tests to Intel GPU (pytorch#156575)
For pytorch#114850, we will port test cases to Intel GPU. Two dynamo test files were ported in PR pytorch#156056. In this PR we port three more dynamo test files. We enable Intel GPU with the following methods while keeping the original code style as much as possible (a sketch of the pattern follows this list):
- use instantiate_device_type_tests()
- use torch.accelerator.current_accelerator() to determine the accelerator backend
- add XPU support in decorators like @requires_gpu
- enable XPU for some test paths

Pull Request resolved: pytorch#156575
Approved by: https://github.com/guangyey, https://github.com/jansel
Co-authored-by: Yu, Guangye <[email protected]>
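Note: as context for the approach above, here is a minimal sketch of the device-generic pattern used in the ported files. It is not code from this commit; the class and test names are hypothetical, and it assumes a recent PyTorch with instantiate_device_type_tests() and the allow_xpu flag.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


class ExampleDeviceTests(TestCase):  # hypothetical test class, for illustration only
    def test_compiled_add(self, device):
        # `device` is injected per backend by instantiate_device_type_tests().
        fn = torch.compile(lambda x: x + 1, backend="eager")
        x = torch.ones(4, device=device)
        self.assertEqual(fn(x), x + 1)


# allow_xpu=True opts the class into Intel GPU (XPU) instantiation as well.
instantiate_device_type_tests(
    ExampleDeviceTests, globals(), only_for=("cpu", "cuda", "hpu", "xpu"), allow_xpu=True
)

if __name__ == "__main__":
    run_tests()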
1 parent 51853b3 commit 1155c53

File tree
4 files changed, +41 / -27 lines


test/dynamo/test_misc.py

Lines changed: 13 additions & 14 deletions

@@ -90,6 +90,7 @@
     skipIfNNModuleInlined,
     skipIfWindows,
     TEST_HPU,
+    TEST_XPU,
     wrapDeterministicFlagAPITest,
 )
 from torch.testing._internal.jit_utils import JitTestCase
@@ -6904,7 +6905,7 @@ def guard_failures(failure):
         self.assertTrue(guard_failure is not None)
         self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0])
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
+    @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test requires CUDA or XPU.")
     def test_symint_as_device_kwarg_non_strict_export(self):
         class Mod(torch.nn.Module):
             def forward(self, x):
@@ -12771,7 +12772,7 @@ def forward(self, query, key, value):
 
     def test_torch_device_is_available(self, device):
         def fn(x):
-            if TEST_HPU or TEST_CUDA:
+            if torch.accelerator.is_available():
                 return x + 1
             else:
                 return x - 1
@@ -12874,27 +12875,23 @@ def f(rank):
     def test_cuda_set_device(self, device):
         def fn():
             a = torch.ones(2, device=device)
-            torch.cuda.set_device(1)
+            torch.get_device_module(device).set_device(1)
             return a + 1
 
-        with torch.cuda.device(0):
+        with torch.get_device_module(device).device(0):
             counter = CompileCounter()
             opt_fn = torch.compile(fn, backend=counter)
             res = opt_fn()
-            self.assertEqual(res.device.type, "cuda")
+            self.assertEqual(res.device.type, device)
             self.assertEqual(res.device.index, 0)
             self.assertEqual(counter.frame_count, 2)
 
-    def test_torch_device_python_type(self):
+    def test_torch_device_python_type(self, device):
+        device_type = torch.device(device).type
         for device, device_type, index in [
             ("cpu", "cpu", None),
-            ("cuda:0", "cuda", 0),
-            ("hpu:0", "hpu", 0),
+            (device, device_type, 0),
         ]:
-            if (device == "cuda:0" and not TEST_CUDA) or (
-                device == "hpu:0" and not TEST_HPU
-            ):
-                continue
 
             def fn(target):
                 target_device = target.device
@@ -12956,8 +12953,10 @@ def f(actions, n_act, epsilon=0.1):
         f(x, y)
 
 
-devices = ("cuda", "hpu")
-instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
+devices = ("cuda", "hpu", "xpu")
+instantiate_device_type_tests(
+    MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
+)
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
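Note: the test_cuda_set_device change above relies on torch.get_device_module(), which maps a device string to its backend module (torch.cuda, torch.xpu, ...). A standalone hedged sketch of that indirection, illustrative only and not code from this commit:

import torch

# Pick whichever accelerator backend is present; fall back to CPU.
device = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"

mod = torch.get_device_module(device)  # torch.cuda for "cuda", torch.xpu for "xpu", torch.cpu for "cpu"
if device in ("cuda", "xpu"):  # backends whose modules expose set_device()/device()
    mod.set_device(0)          # backend-agnostic form of torch.cuda.set_device(0)
    with mod.device(0):        # backend-agnostic form of torch.cuda.device(0)
        t = torch.ones(2, device=device)
        assert t.device.type == device and t.device.index == 0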

test/dynamo/test_model_output.py

Lines changed: 5 additions & 5 deletions

@@ -7,7 +7,7 @@
 import torch._dynamo.testing
 from torch._dynamo.testing import same
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import TEST_HPU, TestCase
+from torch.testing._internal.common_utils import TestCase
 
 
 try:
@@ -359,11 +359,11 @@ def forward(
         )
 
 
-devices = ["cpu", "cuda"]
-if TEST_HPU:
-    devices.append("hpu")
+devices = ["cpu", "cuda", "xpu", "hpu"]
 
-instantiate_device_type_tests(TestModelOutputBert, globals(), only_for=devices)
+instantiate_device_type_tests(
+    TestModelOutputBert, globals(), only_for=devices, allow_xpu=True
+)
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
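Note: the device-list change above appears to work because instantiate_device_type_tests() only instantiates test classes for backends that are actually present at runtime (with XPU additionally requiring allow_xpu=True), which is presumably why the explicit TEST_HPU check could be dropped. A minimal hedged sketch with a hypothetical test class:

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


class ExampleModelOutputTests(TestCase):  # hypothetical, for illustration only
    def test_device_roundtrip(self, device):
        # One concrete class per available backend is generated
        # (e.g. ...CPU, ...CUDA, ...XPU), each passing its device string here.
        x = torch.zeros(2, 3, device=device)
        self.assertEqual(x.device.type, torch.device(device).type)


# Assumption: backends listed in only_for but unavailable at runtime are simply
# not instantiated, so the list can be written unconditionally.
devices = ["cpu", "cuda", "xpu", "hpu"]
instantiate_device_type_tests(
    ExampleModelOutputTests, globals(), only_for=devices, allow_xpu=True
)

if __name__ == "__main__":
    run_tests()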

test/dynamo/test_modes.py

Lines changed: 11 additions & 4 deletions

@@ -12,11 +12,16 @@
     _push_on_torch_function_stack,
 )
 from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
-from torch.testing._internal.triton_utils import requires_cuda
+from torch.testing._internal.triton_utils import requires_gpu
 from torch.utils._device import DeviceContext
 from torch.utils._python_dispatch import TorchDispatchMode
 
 
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
+
 class TestMode(BaseTorchFunctionMode):
     def __torch_function__(self, func, types, args, kwargs=None):
         if not kwargs:
@@ -613,12 +618,12 @@ def func(a):
 
         func(torch.randn(3))
 
-    @requires_cuda
+    @requires_gpu
    def test_flex_attention(self):
         import torch
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
-        torch.set_default_device("cuda")
+        torch.set_default_device(device_type)
 
         flex_attention = torch.compile(flex_attention, dynamic=False)
 
@@ -628,7 +633,9 @@ def prefix_lm(b, h, q, kv):
             return prefix_lengths[b] >= kv
 
         # This runs in fullgraph already
-        create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True)
+        create_block_mask(
+            prefix_lm, 8, None, 512, 512, _compile=True, device=device_type
+        )
 
     def test_register_hook(self):
         import functools
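Note: the module-level device_type added above leans on torch.accelerator.current_accelerator(check_available=True), which returns None when no accelerator backend is usable. A standalone hedged sketch of the same pattern, separate from the test file:

import torch

# current_accelerator(True) checks runtime availability and yields None when no
# accelerator (CUDA/XPU/HPU/MPS/...) is usable, so the expression degrades to "cpu".
device_type = (
    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
)

# Downstream code can then be written device-agnostically, e.g.:
torch.set_default_device(device_type)
x = torch.ones(3)                     # allocated on the detected backend (or CPU)
assert x.device.type == device_type
torch.set_default_device("cpu")       # restore the default afterwards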

test/dynamo/test_package.py

Lines changed: 12 additions & 4 deletions

@@ -16,7 +16,7 @@
     instantiate_parametrized_tests,
     parametrize,
 )
-from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU
 
 
 @functorch_config.patch("bundled_autograd_cache", True)
@@ -28,10 +28,13 @@ def path(self):
         return path
 
     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_basic_fn(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()
 
         def fn(x):
@@ -69,10 +72,12 @@ def fn(x):
         self.assertEqual(expected, compiled_fn(*args))
 
     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_graph_break_bomb(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
 
         ctx = DiskDynamoStore()
 
@@ -131,10 +136,13 @@ def guard_filter_fn(guards):
         compiled_fn(torch.tensor(N), 0, N - 1)
 
     @parametrize("backend", ("eager", "inductor"))
-    @parametrize("device", ("cpu", "cuda"))
+    @parametrize("device", ("cpu", "cuda", "xpu"))
     def test_dynamic_shape(self, backend, device):
         if device == "cuda" and not HAS_CUDA:
             raise unittest.SkipTest("Requires CUDA/Triton")
+        if device == "xpu" and not HAS_XPU:
+            raise unittest.SkipTest("Requires XPU/Triton")
+
         ctx = DiskDynamoStore()
 
         def fn(x):
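Note: the skip logic added in test_package.py follows a common parametrize-and-gate pattern; a minimal hedged sketch is below. The test class and body are hypothetical, not code from this PR; only HAS_CUDA/HAS_XPU and the skip messages mirror the diff above.

import unittest

import torch
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU


class ExamplePackageTests(TestCase):  # hypothetical, for illustration only
    @parametrize("device", ("cpu", "cuda", "xpu"))
    def test_compile_on_device(self, device):
        # Gate each GPU backend on its Triton toolchain, mirroring the diff above.
        if device == "cuda" and not HAS_CUDA:
            raise unittest.SkipTest("Requires CUDA/Triton")
        if device == "xpu" and not HAS_XPU:
            raise unittest.SkipTest("Requires XPU/Triton")

        fn = torch.compile(lambda x: x * 2, backend="eager")
        x = torch.ones(4, device=device)
        self.assertEqual(fn(x), x * 2)


instantiate_parametrized_tests(ExamplePackageTests)

if __name__ == "__main__":
    run_tests()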
