Commit da892d2

PawelSwider2000 authored and pytorchmergebot committed
Enable test_triton_fx_graph_with_et_xpu to run with XPU (pytorch#169181)
Change the hardcoded `"cuda:0"` to the `device` parameter so that `test_triton_fx_graph_with_et` can run on different devices; in particular, the test now passes on XPU. Simplify the skip conditions and apply a minor refactor.

Fixes intel/torch-xpu-ops#2040
Pull Request resolved: pytorch#169181
Approved by: https://github.com/guangyey, https://github.com/jansel, https://github.com/EikanWang
1 parent 4816fd9 commit da892d2
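The simplified skip condition mentioned above is a direct De Morgan rewrite of the old predicate. A minimal sketch of the equivalence (not part of the commit; `has_triton`, `TEST_CUDA`, and `TEST_XPU` stand in here as plain booleans for the helpers and flags the test file imports from PyTorch's test utilities):

from itertools import product

def old_skip(has_triton, TEST_CUDA, TEST_XPU):
    # Old form: skip if triton is missing, or if neither CUDA nor XPU is available.
    return (not has_triton) or (not TEST_CUDA and not TEST_XPU)

def new_skip(has_triton, TEST_CUDA, TEST_XPU):
    # New form: skip unless triton is present and at least one device is available.
    return not (has_triton and (TEST_CUDA or TEST_XPU))

# The two predicates agree for every combination of availability flags.
assert all(
    old_skip(*flags) == new_skip(*flags) for flags in product([True, False], repeat=3)
)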


test/profiler/test_execution_trace.py

Lines changed: 14 additions & 17 deletions
@@ -379,7 +379,7 @@ def test_execution_trace_env_disabled(self, device):
         sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
     )
     @unittest.skipIf(
-        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
+        not (has_triton() and (TEST_CUDA or TEST_XPU)),
         "need triton and device(CUDA or XPU) availability to run",
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
@@ -438,7 +438,7 @@ def fn(a, b, c):
         sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
     )
     @unittest.skipIf(
-        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
+        not (has_triton() and (TEST_CUDA or TEST_XPU)),
         "need triton and device(CUDA or XPU) availability to run",
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
@@ -500,8 +500,8 @@ def fn(a, b, c):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
     @unittest.skipIf(
-        (not has_triton()) or (not TEST_CUDA),
-        "need triton and device CUDA availability to run",
+        not (has_triton() and (TEST_CUDA or TEST_XPU)),
+        "need triton and device(CUDA or XPU) availability to run",
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
     def test_triton_fx_graph_with_et(self, device):
@@ -510,8 +510,6 @@ def test_triton_fx_graph_with_et(self, device):
 
         PyCodeCache.cache_clear(purge=True)
 
-        import os
-
         @torchdynamo.optimize("inductor")
         def fn(a, b, c):
             x = torch.nn.functional.linear(a, b)
@@ -520,15 +518,14 @@ def fn(a, b, c):
             return x.cos()
 
         a, b, c = (
-            torch.randn(4, 4, requires_grad=False).to(torch.device("cuda:0"))
+            torch.randn(4, 4, requires_grad=False).to(torch.device(device))
             for _ in range(3)
         )
 
-        inputs = [a, b, c]
         with torch._inductor.config.patch(
             compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False
         ):
-            fn(*inputs)
+            fn(a, b, c)
 
         et = ExecutionTraceObserver()
         with tempfile.NamedTemporaryFile(
@@ -546,7 +543,7 @@ def fn(a, b, c):
         ) as p:
             for idx in range(10):
                 with record_function(f"## LOOP {idx} ##"):
-                    fn(*inputs)
+                    fn(a, b, c)
                 p.step()
 
         et_path = p.execution_trace_observer.get_output_file_path()
@@ -576,31 +573,31 @@ def fn(a, b, c):
         if len(fx_graph) > 0:
             assert (
                 fx_graph[0]
-                == '# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]'
+                == f'# %mm : Tensor "f32[4, 4][4, 1]{device}" = PlaceHolder[target=mm]'
             )
             assert (
                 fx_graph[1]
-                == '# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]'
+                == f'# %arg2_1 : Tensor "f32[4, 4][4, 1]{device}" = PlaceHolder[target=arg2_1]'
             )
             assert (
                 fx_graph[2]
-                == '# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})'  # noqa: B950
+                == f'# %sin : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {{}})'  # noqa: B950
             )
             assert (
                 fx_graph[3]
-                == '# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})'  # noqa: B950
+                == f'# %permute_1 : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {{}})'  # noqa: B950
             )
             assert (
                 fx_graph[4]
-                == '# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})'  # noqa: B950
+                == f'# %mul : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {{}})'  # noqa: B950
            )
             assert (
                 fx_graph[5]
-                == '# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})'  # noqa: B950
+                == f'# %add : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {{}})'  # noqa: B950
             )
             assert (
                 fx_graph[6]
-                == '# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})'  # noqa: B950
+                == f'# %cos : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {{}})'  # noqa: B950
             )
             assert fx_graph[7] == "# return %cos"
         os.remove(file_path)
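A side note on the assertion rewrite above: switching the expected strings to f-strings lets the `{device}` placeholder be filled by the test's `device` parameter, while any literal braces in the expected output must be doubled (`{{}}`) so they survive formatting. A minimal sketch of how one expectation renders, assuming "xpu:0" as an example value of `device` (on a CUDA runner it would be "cuda:0"):

# Assumed example value of the `device` parameter, for illustration only.
device = "xpu:0"

expected_mm = f'# %mm : Tensor "f32[4, 4][4, 1]{device}" = PlaceHolder[target=mm]'
expected_sin = (
    f'# %sin : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] '
    f"= call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {{}})"
)

# The doubled braces render as a literal "{}" next to the substituted device string.
assert expected_sin.endswith("kwargs = {})")
assert "xpu:0" in expected_mm and "xpu:0" in expected_sin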
