@@ -63,7 +63,9 @@ def kernel(C, A, B, M, N, K,
6363 tl.store(c_ptrs, c)
6464"""
6565
66- gluon_kernel_src = """
66+
67+ def get_gluon_kernel_src (threads_per_warp ):
68+ return f"""
6769from triton.experimental import gluon
6870from triton.experimental.gluon import language as gl
6971
@@ -77,12 +79,13 @@ def kernel(
7779 BLOCK_N: gl.constexpr,
7880 BLOCK_K: gl.constexpr
7981):
80- layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[64 ], warps_per_cta=[1], order=[0])
82+ layout: gl.constexpr = gl.BlockedLayout(size_per_thread=[1], threads_per_warp=[{ threads_per_warp } ], warps_per_cta=[1], order=[0])
8183 offs = gl.arange(0, 64, layout=layout)
8284 a = gl.load(A + offs)
8385 gl.store(B + offs, a)
8486"""
8587
88+
8689test_utils_src = """
8790#include <cuda.h>
8891#include <stdio.h>
@@ -215,34 +218,21 @@ def write_triton_kernels(dir, src, util_src):
215218 return kernel_path
216219
217220
218- def _compile_kernel (dir , signature , kernel_name , out_name , out_path , num_warps , grid , kernel_path ):
221+ def _compile_kernel (dir , signature , kernel_name , out_name , out_path , num_warps , grid , kernel_path , target = None ):
219222 compiler_path = os .path .join (triton .tools .__path__ [0 ], "compile.py" )
220-
221- subprocess .run (
222- [
223- sys .executable ,
224- compiler_path ,
225- "-n" ,
226- kernel_name ,
227- "--signature" ,
228- signature ,
229- "--out-name" ,
230- out_name ,
231- "-o" ,
232- out_path ,
233- "-w" ,
234- str (num_warps ),
235- "-g" ,
236- grid ,
237- kernel_path ,
238- ],
239- check = True ,
240- cwd = dir ,
241- )
223+ cmd_args = [
224+ sys .executable , compiler_path , "-n" , kernel_name , "--signature" , signature , "--out-name" , out_name , "-o" ,
225+ out_path , "-w" ,
226+ str (num_warps ), "-g" , grid
227+ ]
228+ if target :
229+ cmd_args .extend (["-t" , "%s:%s:%i" % (target .backend , target .arch , target .warp_size )])
230+ cmd_args .append (kernel_path )
231+ subprocess .run (cmd_args , check = True , cwd = dir )
242232
243233
244234# Edge case kernel with no specialization
245- def compile_aot_kernel_no_specialization (dir , kernel_path , dtype , BM , BN , BK ):
235+ def compile_aot_kernel_no_specialization (dir , kernel_path , dtype , BM , BN , BK , target = None ):
246236 # compile all desired configs
247237 sig = f"*fp32, *{ dtype } , *{ dtype } , i32, i32, i32, i32, i32, i32, i32, i32, i32, { BM } , { BN } , { BK } "
248238 name = f"matmul_{ dtype } "
@@ -256,10 +246,11 @@ def compile_aot_kernel_no_specialization(dir, kernel_path, dtype, BM, BN, BK):
256246 num_warps = 1 ,
257247 grid = grid ,
258248 kernel_path = kernel_path ,
249+ target = target ,
259250 )
260251
261252
262- def compile_aot_kernels (dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints ):
253+ def compile_aot_kernels (dir , kernel_path , dtype , BM , BN , BK , ha_hb_hints , target = None ):
263254 # compile all desired configs
264255 for ha , hb in ha_hb_hints :
265256 sig = f"*fp32:16, *{ dtype } :16, *{ dtype } :16, i32, i32, i32, i32{ ha } , i32:1, i32{ hb } , i32:1, i32:16, i32:1, { BM } , { BN } , { BK } "
@@ -274,6 +265,7 @@ def compile_aot_kernels(dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints):
274265 num_warps = 1 ,
275266 grid = grid ,
276267 kernel_path = kernel_path ,
268+ target = target ,
277269 )
278270
279271
@@ -492,13 +484,13 @@ def test_ttgir_to_asm():
492484 assert '.wavefront_size: 64' in amdgcn
493485
494486
495- def test_gluon_kernel ():
496- if not is_hip ():
497- pytest . skip ( "Gluon kernel is only supported on HIP" )
487+ @ pytest . mark . parametrize ( "target" , [ GPUTarget ( "hip" , "gfx942" , 64 ), GPUTarget ( "hip" , "gfx1250" , 32 )])
488+ @ pytest . mark . skipif ( not is_hip (), reason = "Requires HIP" )
489+ def test_gluon_kernel ( target ):
498490 with tempfile .TemporaryDirectory () as tmp_dir :
499491 dtype = "fp16"
500492 BM , BN , BK = 16 , 16 , 16
501-
493+ gluon_kernel_src = get_gluon_kernel_src ( target . warp_size )
502494 kernel_path = write_triton_kernels (tmp_dir , gluon_kernel_src , kernel_utils_src )
503- compile_aot_kernel_no_specialization (tmp_dir , kernel_path , dtype , BM , BN , BK )
495+ compile_aot_kernel_no_specialization (tmp_dir , kernel_path , dtype , BM , BN , BK , target = target )
504496 check_hasco_binary_str (tmp_dir , dtype )
0 commit comments