Skip to content

Commit 9d9ecdb

Browse files
weishi-deng authored and pytorchmergebot committed
[xpu][feature] Enable triton online softmax kernels on XPU. (pytorch#163251)
This PR enables the Triton online softmax kernels for XPU devices by adding a device check in prepare_softmax_extra_check. Pull Request resolved: pytorch#163251. Approved by: https://github.com/etaf, https://github.com/EikanWang, https://github.com/mlazos
1 parent d24276f commit 9d9ecdb

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

test/inductor/test_online_softmax.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
IS_LINUX,
1515
parametrize,
1616
)
17-
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA_AND_TRITON
17+
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, HAS_TRITON
1818

1919

2020
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
@@ -138,8 +138,14 @@ def test_prepare_softmax(self, dim, nrow):
138138
self.assertTrue(same(ref, act, tol=1e-2))
139139

140140
if nrow == 2048 and dim == 0:
141+
num_kernels = 2
142+
# Note: split reduction is not triggered for this shape on some xpu devices.
143+
# check "num_splits" for more details
144+
if GPU_TYPE == "xpu":
145+
num_kernels = 1
146+
141147
# split reduction is triggered. We have multiple kernels
142-
self.assertTrue(code.count("def triton") >= 2)
148+
self.assertTrue(code.count("def triton") >= num_kernels)
143149
else:
144150
if nrow == 2 and dim == 0:
145151
# persistent reduction triggered
@@ -310,5 +316,5 @@ def f(x, y):
310316
instantiate_parametrized_tests(TestOnlineSoftmax)
311317

312318
if __name__ == "__main__":
313-
if IS_LINUX and HAS_CUDA_AND_TRITON:
319+
if IS_LINUX and HAS_GPU and HAS_TRITON:
314320
run_tests()

0 commit comments

Comments (0)