@@ -379,7 +379,7 @@ def test_execution_trace_env_disabled(self, device):
         sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
     )
     @unittest.skipIf(
-        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
+        not (has_triton() and (TEST_CUDA or TEST_XPU)),
         "need triton and device(CUDA or XPU) availability to run",
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
@@ -438,7 +438,7 @@ def fn(a, b, c):
         sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
     )
     @unittest.skipIf(
-        (not has_triton()) or (not TEST_CUDA and not TEST_XPU),
+        not (has_triton() and (TEST_CUDA or TEST_XPU)),
         "need triton and device(CUDA or XPU) availability to run",
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
@@ -500,8 +500,8 @@ def fn(a, b, c):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS")
     @unittest.skipIf(
-        (not has_triton()) or (not TEST_CUDA),
-        "need triton and device CUDA availability to run",
+        not (has_triton() and (TEST_CUDA or TEST_XPU)),
+        "need triton and device(CUDA or XPU) availability to run",
     )
     @skipCPUIf(True, "skip CPU device for testing profiling triton")
     def test_triton_fx_graph_with_et(self, device):
@@ -510,8 +510,6 @@ def test_triton_fx_graph_with_et(self, device):
 
         PyCodeCache.cache_clear(purge=True)
 
-        import os
-
         @torchdynamo.optimize("inductor")
         def fn(a, b, c):
             x = torch.nn.functional.linear(a, b)
@@ -520,15 +518,14 @@ def fn(a, b, c):
             return x.cos()
 
         a, b, c = (
-            torch.randn(4, 4, requires_grad=False).to(torch.device("cuda:0"))
+            torch.randn(4, 4, requires_grad=False).to(torch.device(device))
             for _ in range(3)
        )
 
-        inputs = [a, b, c]
         with torch._inductor.config.patch(
             compile_threads=1, fx_graph_cache=False, fx_graph_remote_cache=False
         ):
-            fn(*inputs)
+            fn(a, b, c)
 
         et = ExecutionTraceObserver()
         with tempfile.NamedTemporaryFile(
@@ -546,7 +543,7 @@ def fn(a, b, c):
         ) as p:
             for idx in range(10):
                 with record_function(f"## LOOP {idx} ##"):
-                    fn(*inputs)
+                    fn(a, b, c)
                 p.step()
 
         et_path = p.execution_trace_observer.get_output_file_path()
@@ -576,31 +573,31 @@ def fn(a, b, c):
         if len(fx_graph) > 0:
             assert (
                 fx_graph[0]
-                == '# %mm : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=mm]'
+                == f'# %mm : Tensor "f32[4, 4][4, 1]{device}" = PlaceHolder[target=mm]'
             )
             assert (
                 fx_graph[1]
-                == '# %arg2_1 : Tensor "f32[4, 4][4, 1]cuda:0" = PlaceHolder[target=arg2_1]'
+                == f'# %arg2_1 : Tensor "f32[4, 4][4, 1]{device}" = PlaceHolder[target=arg2_1]'
             )
             assert (
                 fx_graph[2]
-                == '# %sin : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {})'  # noqa: B950
+                == f'# %sin : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {{}})'  # noqa: B950
             )
             assert (
                 fx_graph[3]
-                == '# %permute_1 : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {})'  # noqa: B950
+                == f'# %permute_1 : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {{}})'  # noqa: B950
             )
             assert (
                 fx_graph[4]
-                == '# %mul : Tensor "f32[4, 4][4, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {})'  # noqa: B950
+                == f'# %mul : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {{}})'  # noqa: B950
             )
             assert (
                 fx_graph[5]
-                == '# %add : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {})'  # noqa: B950
+                == f'# %add : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {{}})'  # noqa: B950
             )
             assert (
                 fx_graph[6]
-                == '# %cos : Tensor "f32[4, 4][1, 4]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {})'  # noqa: B950
+                == f'# %cos : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {{}})'  # noqa: B950
             )
             assert fx_graph[7] == "# return %cos"
         os.remove(file_path)