|
| 1 | +import sys |
| 2 | +import importlib.util |
| 3 | +import torch |
| 4 | +import triton |
| 5 | +import triton.language as tl |
| 6 | +import pytest |
| 7 | +from triton.tools.tensor_descriptor import TensorDescriptor |
| 8 | + |
| 9 | +from triton.tools.triton_to_gluon_translater.translator import convert_triton_to_gluon |
| 10 | +from triton.tools.triton_to_gluon_translater.translator_helpers import convert_host_descriptor |
| 11 | +from triton._internal_testing import ( |
| 12 | + is_blackwell, ) |
| 13 | + |
| 14 | + |
| 15 | +def convert_kernel(kernel, kernel_name, tmp_path): |
| 16 | + converted = convert_triton_to_gluon(kernel) |
| 17 | + |
| 18 | + # Write converted kernel to a file so @gluon.jit can retrieve source |
| 19 | + mod_path = tmp_path / "converted_kernel.py" |
| 20 | + mod_path.write_text(converted) |
| 21 | + |
| 22 | + spec = importlib.util.spec_from_file_location("converted_kernel", mod_path) |
| 23 | + module = importlib.util.module_from_spec(spec) |
| 24 | + sys.modules["converted_kernel"] = module |
| 25 | + assert spec.loader is not None |
| 26 | + spec.loader.exec_module(module) |
| 27 | + kernel = getattr(module, kernel_name) |
| 28 | + return kernel |
| 29 | + |
| 30 | + |
| 31 | +@triton.jit |
| 32 | +def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr): |
| 33 | + pid = tl.program_id(0) |
| 34 | + offsets = pid * BLOCK + tl.arange(0, BLOCK) |
| 35 | + x = tl.load(x_ptr + offsets) |
| 36 | + y = tl.load(y_ptr + offsets) |
| 37 | + tl.store(out_ptr + offsets, x + y) |
| 38 | + |
| 39 | + |
| 40 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 41 | +def test_simple_kernel(tmp_path): |
| 42 | + kernel = convert_kernel(add_kernel, "add_kernel", tmp_path) |
| 43 | + |
| 44 | + n = 1024 |
| 45 | + BLOCK = 128 |
| 46 | + x = torch.randn(n, device="cuda", dtype=torch.float32) |
| 47 | + y = torch.randn(n, device="cuda", dtype=torch.float32) |
| 48 | + out = torch.empty_like(x) |
| 49 | + grid = (n // BLOCK, ) |
| 50 | + kernel[grid](x, y, out, n, BLOCK) |
| 51 | + |
| 52 | + ref = torch.empty_like(x) |
| 53 | + add_kernel[grid](x, y, ref, n, BLOCK) |
| 54 | + |
| 55 | + torch.testing.assert_close(out, ref) |
| 56 | + |
| 57 | + |
| 58 | +@triton.jit |
| 59 | +def impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): |
| 60 | + offs_m = tl.arange(0, M)[:, None] |
| 61 | + offs_n = tl.arange(0, N)[None, :] |
| 62 | + acc = tl.zeros((M, N), dtype=tl.float32) |
| 63 | + a = tl.load(a_ptr + offs_m * K + (tl.arange(0, K))[None, :]) |
| 64 | + b = tl.load(b_ptr + (tl.arange(0, K))[:, None] * N + offs_n) |
| 65 | + acc += tl.dot(a, b) |
| 66 | + tl.store(c_ptr + offs_m * N + offs_n, acc) |
| 67 | + |
| 68 | + |
| 69 | +@triton.jit |
| 70 | +def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): |
| 71 | + impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M, BLOCK_N, BLOCK_K) |
| 72 | + |
| 73 | + |
| 74 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 75 | +def test_triton_to_gluon_dot_minimal(tmp_path): |
| 76 | + # Convert directly from the Triton kernel object |
| 77 | + kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path) |
| 78 | + M, N, K = 128, 128, 128 |
| 79 | + a = torch.randn((M, K), device="cuda", dtype=torch.float16) |
| 80 | + b = torch.randn((K, N), device="cuda", dtype=torch.float16) |
| 81 | + grid = (1, ) |
| 82 | + |
| 83 | + c = torch.empty((M, N), device="cuda", dtype=torch.float32) |
| 84 | + kernel[grid](a, b, c, M, N, K, num_warps=8) |
| 85 | + |
| 86 | + ref = torch.empty_like(c) |
| 87 | + matmul_tile_kernel[grid](a, b, ref, M, N, K, num_warps=8) |
| 88 | + torch.testing.assert_close(c, ref) |
| 89 | + |
| 90 | + |
| 91 | +@triton.jit |
| 92 | +def matmul_kernel( # |
| 93 | + a_ptr, |
| 94 | + b_ptr, |
| 95 | + output_ptr, # |
| 96 | + M, |
| 97 | + N, |
| 98 | + K, # |
| 99 | + stride_am, |
| 100 | + stride_ak, # |
| 101 | + stride_bk, |
| 102 | + stride_bn, # |
| 103 | + stride_cm, |
| 104 | + stride_cn, # |
| 105 | + BLOCK_M: tl.constexpr, |
| 106 | + BLOCK_N: tl.constexpr, |
| 107 | + BLOCK_K: tl.constexpr, |
| 108 | +): |
| 109 | + pid = tl.program_id(axis=0) |
| 110 | + num_pid_m = tl.cdiv(M, BLOCK_M) |
| 111 | + pid_m = pid % num_pid_m |
| 112 | + pid_n = pid // num_pid_m |
| 113 | + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M |
| 114 | + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N |
| 115 | + offs_k = tl.arange(0, BLOCK_K) |
| 116 | + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) |
| 117 | + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) |
| 118 | + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty) |
| 119 | + for k in tl.range(0, tl.cdiv(K, BLOCK_K), step=1, num_stages=4): |
| 120 | + a = tl.load(a_ptrs) |
| 121 | + b = tl.load(b_ptrs) |
| 122 | + accumulator = tl.dot(a, b, acc=accumulator, out_dtype=output_ptr.dtype.element_ty) |
| 123 | + a_ptrs += BLOCK_K * stride_ak |
| 124 | + b_ptrs += BLOCK_K * stride_bk |
| 125 | + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
| 126 | + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
| 127 | + output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] |
| 128 | + tl.store(output_ptrs, accumulator) |
| 129 | + |
| 130 | + |
| 131 | +@pytest.mark.parametrize("dtype_src_str", ["float16"]) |
| 132 | +@pytest.mark.parametrize("dtype_dst_str", ["float32"]) |
| 133 | +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 64, 1)]) |
| 134 | +@pytest.mark.parametrize("NUM_WARPS", [4]) |
| 135 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 136 | +def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, tmp_path): |
| 137 | + device = "cuda" |
| 138 | + M, N, K = 1024, 512, 256 |
| 139 | + torch.manual_seed(42) |
| 140 | + dtype_src_str = "float32" if dtype_src_str == "tensorfloat32" else dtype_src_str |
| 141 | + dtype_src = getattr(torch, dtype_src_str) |
| 142 | + |
| 143 | + kernel = convert_kernel(matmul_kernel, "matmul_kernel", tmp_path) |
| 144 | + |
| 145 | + a = torch.randn(M, K, dtype=dtype_src, device=device) |
| 146 | + b = torch.randn(K, N, dtype=dtype_src, device=device) |
| 147 | + dtype_dst = getattr(torch, dtype_dst_str) |
| 148 | + output = torch.empty((M, N), dtype=dtype_dst, device=device) |
| 149 | + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) |
| 150 | + kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0), |
| 151 | + output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K) |
| 152 | + |
| 153 | + ref = torch.empty_like(output) |
| 154 | + matmul_kernel[grid](a, b, ref, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0), |
| 155 | + output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K) |
| 156 | + torch.testing.assert_close(output, ref) |
| 157 | + |
| 158 | + |
| 159 | +@triton.jit |
| 160 | +def descriptor_store_kernel(desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, V: tl.constexpr): |
| 161 | + tile = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16) + V |
| 162 | + desc.store([0, 0], tile) |
| 163 | + |
| 164 | + |
| 165 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 166 | +def test_triton_to_gluon_descriptor_roundtrip(tmp_path): |
| 167 | + kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path) |
| 168 | + |
| 169 | + M = N = 64 |
| 170 | + y = torch.zeros((M, N), device="cuda", dtype=torch.float16) |
| 171 | + grid = (1, ) |
| 172 | + block_shape = [M, N] |
| 173 | + desc = TensorDescriptor(y, y.shape, y.stride(), block_shape) |
| 174 | + gluon_desc = convert_host_descriptor(desc) |
| 175 | + kernel[grid](gluon_desc, M, N, 1.0) |
| 176 | + |
| 177 | + y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16) |
| 178 | + desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape) |
| 179 | + descriptor_store_kernel[grid](desc_ref, M, N, 1.0) |
| 180 | + torch.testing.assert_close(y, y_ref) |
| 181 | + |
| 182 | + |
| 183 | +@triton.jit |
| 184 | +def descriptor_copy_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): |
| 185 | + tile = in_desc.load([0, 0]) |
| 186 | + out_desc.store([0, 0], tile) |
| 187 | + |
| 188 | + |
| 189 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 190 | +def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path): |
| 191 | + kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path) |
| 192 | + |
| 193 | + M = N = 64 |
| 194 | + x = torch.ones((M, N), device="cuda", dtype=torch.float16) * 3.0 |
| 195 | + y = torch.zeros((M, N), device="cuda", dtype=torch.float16) |
| 196 | + grid = (1, ) |
| 197 | + block_shape = [M, N] |
| 198 | + |
| 199 | + in_desc = TensorDescriptor(x, x.shape, x.stride(), block_shape) |
| 200 | + gluon_desc = convert_host_descriptor(in_desc) |
| 201 | + out_desc = convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape)) |
| 202 | + kernel[grid](gluon_desc, out_desc, M, N) |
| 203 | + |
| 204 | + y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16) |
| 205 | + desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape) |
| 206 | + descriptor_copy_kernel[grid](in_desc, desc_ref, M, N) |
| 207 | + torch.testing.assert_close(y, y_ref) |
| 208 | + |
| 209 | + |
| 210 | +@triton.jit |
| 211 | +def reshape_trans_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr): |
| 212 | + pid = tl.program_id(0) |
| 213 | + offsets = pid * BLOCK + tl.arange(0, BLOCK) |
| 214 | + |
| 215 | + x = tl.reshape(tl.load(x_ptr + offsets), 16, 16) |
| 216 | + y = tl.load(y_ptr + offsets).reshape(16, 16) |
| 217 | + a = x + y.trans(1, 0) |
| 218 | + a = a.reshape(256) |
| 219 | + tl.store(out_ptr + offsets, a) |
| 220 | + |
| 221 | + |
| 222 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 223 | +def test_triton_reshape_trans(tmp_path): |
| 224 | + kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path) |
| 225 | + |
| 226 | + n = 1024 |
| 227 | + BLOCK = 256 |
| 228 | + x = torch.randn(n, device="cuda", dtype=torch.float32) |
| 229 | + y = torch.randn(n, device="cuda", dtype=torch.float32) |
| 230 | + out = torch.empty_like(x) |
| 231 | + grid = (n // BLOCK, ) |
| 232 | + kernel[grid](x, y, out, n, BLOCK) |
| 233 | + ref = torch.empty_like(x) |
| 234 | + reshape_trans_kernel[grid](x, y, ref, n, BLOCK) |
| 235 | + torch.testing.assert_close(out, ref) |
| 236 | + |
| 237 | + |
| 238 | +BLOCK_SPLIT = tl.constexpr(256) |
| 239 | + |
| 240 | + |
| 241 | +@triton.jit |
| 242 | +def split_kernel(x_ptr, out_ptr): |
| 243 | + pid = tl.program_id(0) |
| 244 | + offsets = pid * BLOCK_SPLIT + tl.arange(0, BLOCK_SPLIT) |
| 245 | + offsets2 = pid * BLOCK_SPLIT + tl.arange(0, 2 * BLOCK_SPLIT) |
| 246 | + |
| 247 | + s0, s1 = tl.reshape(tl.load(x_ptr + offsets2), BLOCK_SPLIT, 2).split() |
| 248 | + a = s0 + s1 |
| 249 | + p = out_ptr + offsets |
| 250 | + tl.store(p, a) |
| 251 | + |
| 252 | + |
| 253 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 254 | +def test_split(tmp_path): |
| 255 | + kernel = convert_kernel(split_kernel, "split_kernel", tmp_path) |
| 256 | + |
| 257 | + n = 1024 |
| 258 | + x = torch.randn(2 * n, device="cuda", dtype=torch.float32) |
| 259 | + grid = (n // BLOCK_SPLIT, ) |
| 260 | + |
| 261 | + out = torch.empty_like(x[:n]) |
| 262 | + kernel[grid](x, out) |
| 263 | + ref = torch.empty_like(x[:n]) |
| 264 | + split_kernel[grid](x, ref) |
| 265 | + torch.testing.assert_close(out, ref) |
| 266 | + |
| 267 | + |
| 268 | +@triton.jit |
| 269 | +def reduce_to_scalar_kernel(out_ptr): |
| 270 | + x = tl.arange(0, 16) |
| 271 | + x = tl.sum(x) |
| 272 | + tl.store(out_ptr, x) |
| 273 | + |
| 274 | + |
| 275 | +@pytest.mark.skipif(not (is_blackwell()), reason="Requires Blackwell") |
| 276 | +def test_reduce_to_scalar(tmp_path): |
| 277 | + kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path) |
| 278 | + grid = (1, ) |
| 279 | + |
| 280 | + out = torch.empty((1, ), device="cuda", dtype=torch.int32) |
| 281 | + kernel[grid](out) |
| 282 | + ref = torch.empty_like(out) |
| 283 | + reduce_to_scalar_kernel[grid](ref) |
| 284 | + torch.testing.assert_close(out, ref) |
0 commit comments