Skip to content

Commit 8b3fb1e

Browse files
authored
[AMD] Fixing test_aot.py::test_gluon_kernel for gfx1250 (#8958)
Test `python/test/unit/tools/test_aot.py::test_gluon_kernel` fails on gfx1250 since it has some properties (64 threads/wave) hard-coded to gfx942, however pulls compilation target from the driver (despite this being aot-compilation). With gluon being somewhat architecture specific, this PR modifies aot testing to specify the architecture which gluon is targeting; fixes the above-mentioned test failure for gfx1250 (32 threads/wave).
1 parent df930ef commit 8b3fb1e

File tree

1 file changed

+24
-32
lines changed

1 file changed

+24
-32
lines changed

python/test/unit/tools/test_aot.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ def kernel(C, A, B, M, N, K,
6363
tl.store(c_ptrs, c)
6464
"""
6565

66-
gluon_kernel_src = """
66+
67+
def get_gluon_kernel_src(threads_per_warp):
68+
return f"""
6769
from triton.experimental import gluon
6870
from triton.experimental.gluon import language as gl
6971
@@ -77,12 +79,13 @@ def kernel(
7779
BLOCK_N: gl.constexpr,
7880
BLOCK_K: gl.constexpr
7981
):
80-
layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64], warps_per_cta=[1], order=[0])
82+
layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[{threads_per_warp}], warps_per_cta=[1], order=[0])
8183
offs = gl.arange(0, 64, layout=layout)
8284
a = gl.load(A + offs)
8385
gl.store(B + offs, a)
8486
"""
8587

88+
8689
test_utils_src = """
8790
#include <cuda.h>
8891
#include <stdio.h>
@@ -215,34 +218,21 @@ def write_triton_kernels(dir, src, util_src):
215218
return kernel_path
216219

217220

218-
def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path):
221+
def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path, target=None):
219222
compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
220-
221-
subprocess.run(
222-
[
223-
sys.executable,
224-
compiler_path,
225-
"-n",
226-
kernel_name,
227-
"--signature",
228-
signature,
229-
"--out-name",
230-
out_name,
231-
"-o",
232-
out_path,
233-
"-w",
234-
str(num_warps),
235-
"-g",
236-
grid,
237-
kernel_path,
238-
],
239-
check=True,
240-
cwd=dir,
241-
)
223+
cmd_args = [
224+
sys.executable, compiler_path, "-n", kernel_name, "--signature", signature, "--out-name", out_name, "-o",
225+
out_path, "-w",
226+
str(num_warps), "-g", grid
227+
]
228+
if target:
229+
cmd_args.extend(["-t", "%s:%s:%i" % (target.backend, target.arch, target.warp_size)])
230+
cmd_args.append(kernel_path)
231+
subprocess.run(cmd_args, check=True, cwd=dir)
242232

243233

244234
# Edge case kernel with no specialization
245-
def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
235+
def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK, target=None):
246236
# compile all desired configs
247237
sig = f"*fp32, *{dtype}, *{dtype}, i32, i32, i32, i32, i32, i32, i32, i32, i32, {BM}, {BN}, {BK}"
248238
name = f"matmul_{dtype}"
@@ -256,10 +246,11 @@ def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
256246
num_warps=1,
257247
grid=grid,
258248
kernel_path=kernel_path,
249+
target=target,
259250
)
260251

261252

262-
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
253+
def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints, target=None):
263254
# compile all desired configs
264255
for ha, hb in ha_hb_hints:
265256
sig = f"*fp32:16, *{dtype}:16, *{dtype}:16, i32, i32, i32, i32{ha}, i32:1, i32{hb}, i32:1, i32:16, i32:1, {BM}, {BN}, {BK}"
@@ -274,6 +265,7 @@ def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
274265
num_warps=1,
275266
grid=grid,
276267
kernel_path=kernel_path,
268+
target=target,
277269
)
278270

279271

@@ -492,13 +484,13 @@ def test_ttgir_to_asm():
492484
assert '.wavefront_size: 64' in amdgcn
493485

494486

495-
def test_gluon_kernel():
496-
if not is_hip():
497-
pytest.skip("Gluon kernel is only supported on HIP")
487+
@pytest.mark.parametrize("target", [GPUTarget("hip", "gfx942", 64), GPUTarget("hip", "gfx1250", 32)])
488+
@pytest.mark.skipif(not is_hip(), reason="Requires HIP")
489+
def test_gluon_kernel(target):
498490
with tempfile.TemporaryDirectory() as tmp_dir:
499491
dtype = "fp16"
500492
BM, BN, BK = 16, 16, 16
501-
493+
gluon_kernel_src = get_gluon_kernel_src(target.warp_size)
502494
kernel_path = write_triton_kernels(tmp_dir, gluon_kernel_src, kernel_utils_src)
503-
compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK)
495+
compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK, target=target)
504496
check_hasco_binary_str(tmp_dir, dtype)

0 commit comments

Comments
 (0)