
Commit faa5033

[Tools] Add experimental translator from Triton to Gluon (#8417)
This is not meant for production, but it allows converting a Triton kernel to a naive Gluon version. The Gluon version will, of course, be significantly slower.

Co-authored-by: peterbell10 <[email protected]>
1 parent 1c2c074
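As a minimal usage sketch (the kernel below is a hypothetical example; convert_triton_to_gluon and its source-text return value are taken from the test file in this diff), the translator can be driven directly from a @triton.jit kernel object:

import triton
import triton.language as tl
from triton.tools.triton_to_gluon_translater.translator import convert_triton_to_gluon


@triton.jit
def double_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    # Hypothetical example kernel: out = 2 * x, one block per program.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(out_ptr + offsets, tl.load(x_ptr + offsets) * 2.0)


# convert_triton_to_gluon returns Gluon source text; the tests below write
# it to a module file and import it so @gluon.jit can retrieve the source.
gluon_source = convert_triton_to_gluon(double_kernel)
print(gluon_source)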

File tree

4 files changed, +987 -3 lines changed
Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
import sys
import importlib.util
import torch
import triton
import triton.language as tl
import pytest
from triton.tools.tensor_descriptor import TensorDescriptor

from triton.tools.triton_to_gluon_translater.translator import convert_triton_to_gluon
from triton.tools.triton_to_gluon_translater.translator_helpers import convert_host_descriptor
from triton._internal_testing import is_blackwell


def convert_kernel(kernel, kernel_name, tmp_path):
    # Translate a @triton.jit kernel to Gluon source, then import the result.
    converted = convert_triton_to_gluon(kernel)

    # Write converted kernel to a file so @gluon.jit can retrieve source
    mod_path = tmp_path / "converted_kernel.py"
    mod_path.write_text(converted)

    spec = importlib.util.spec_from_file_location("converted_kernel", mod_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules["converted_kernel"] = module
    assert spec.loader is not None
    spec.loader.exec_module(module)
    kernel = getattr(module, kernel_name)
    return kernel


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(out_ptr + offsets, x + y)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_simple_kernel(tmp_path):
    kernel = convert_kernel(add_kernel, "add_kernel", tmp_path)

    n = 1024
    BLOCK = 128
    x = torch.randn(n, device="cuda", dtype=torch.float32)
    y = torch.randn(n, device="cuda", dtype=torch.float32)
    out = torch.empty_like(x)
    grid = (n // BLOCK, )
    kernel[grid](x, y, out, n, BLOCK)

    # The converted Gluon kernel must match the original Triton kernel.
    ref = torch.empty_like(x)
    add_kernel[grid](x, y, ref, n, BLOCK)

    torch.testing.assert_close(out, ref)


@triton.jit
def impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)[:, None]
    offs_n = tl.arange(0, N)[None, :]
    acc = tl.zeros((M, N), dtype=tl.float32)
    a = tl.load(a_ptr + offs_m * K + tl.arange(0, K)[None, :])
    b = tl.load(b_ptr + tl.arange(0, K)[:, None] * N + offs_n)
    acc += tl.dot(a, b)
    tl.store(c_ptr + offs_m * N + offs_n, acc)


@triton.jit
def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    impl_matmul_tile_kernel(a_ptr, b_ptr, c_ptr, BLOCK_M, BLOCK_N, BLOCK_K)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_triton_to_gluon_dot_minimal(tmp_path):
    # Convert directly from the Triton kernel object
    kernel = convert_kernel(matmul_tile_kernel, "matmul_tile_kernel", tmp_path)
    M, N, K = 128, 128, 128
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    grid = (1, )

    c = torch.empty((M, N), device="cuda", dtype=torch.float32)
    kernel[grid](a, b, c, M, N, K, num_warps=8)

    ref = torch.empty_like(c)
    matmul_tile_kernel[grid](a, b, ref, M, N, K, num_warps=8)
    torch.testing.assert_close(c, ref)


@triton.jit
def matmul_kernel(  #
        a_ptr,
        b_ptr,
        output_ptr,  #
        M,
        N,
        K,  #
        stride_am,
        stride_ak,  #
        stride_bk,
        stride_bn,  #
        stride_cm,
        stride_cn,  #
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), step=1, num_stages=4):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator = tl.dot(a, b, acc=accumulator, out_dtype=output_ptr.dtype.element_ty)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(output_ptrs, accumulator)


@pytest.mark.parametrize("dtype_src_str", ["float16"])
@pytest.mark.parametrize("dtype_dst_str", ["float32"])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 64, 1)])
@pytest.mark.parametrize("NUM_WARPS", [4])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, tmp_path):
    device = "cuda"
    M, N, K = 1024, 512, 256
    torch.manual_seed(42)
    dtype_src_str = "float32" if dtype_src_str == "tensorfloat32" else dtype_src_str
    dtype_src = getattr(torch, dtype_src_str)

    kernel = convert_kernel(matmul_kernel, "matmul_kernel", tmp_path)

    a = torch.randn(M, K, dtype=dtype_src, device=device)
    b = torch.randn(K, N, dtype=dtype_src, device=device)
    dtype_dst = getattr(torch, dtype_dst_str)
    output = torch.empty((M, N), dtype=dtype_dst, device=device)
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
    kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                 output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K)

    ref = torch.empty_like(output)
    matmul_kernel[grid](a, b, ref, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                        output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K)
    torch.testing.assert_close(output, ref)


@triton.jit
def descriptor_store_kernel(desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, V: tl.constexpr):
    tile = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16) + V
    desc.store([0, 0], tile)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_triton_to_gluon_descriptor_roundtrip(tmp_path):
    kernel = convert_kernel(descriptor_store_kernel, "descriptor_store_kernel", tmp_path)

    M = N = 64
    y = torch.zeros((M, N), device="cuda", dtype=torch.float16)
    grid = (1, )
    block_shape = [M, N]
    desc = TensorDescriptor(y, y.shape, y.stride(), block_shape)
    # Host-side descriptors must be converted to their Gluon equivalent.
    gluon_desc = convert_host_descriptor(desc)
    kernel[grid](gluon_desc, M, N, 1.0)

    y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
    desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape)
    descriptor_store_kernel[grid](desc_ref, M, N, 1.0)
    torch.testing.assert_close(y, y_ref)


@triton.jit
def descriptor_copy_kernel(in_desc, out_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    tile = in_desc.load([0, 0])
    out_desc.store([0, 0], tile)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_triton_to_gluon_descriptor_load_roundtrip(tmp_path):
    kernel = convert_kernel(descriptor_copy_kernel, "descriptor_copy_kernel", tmp_path)

    M = N = 64
    x = torch.ones((M, N), device="cuda", dtype=torch.float16) * 3.0
    y = torch.zeros((M, N), device="cuda", dtype=torch.float16)
    grid = (1, )
    block_shape = [M, N]

    in_desc = TensorDescriptor(x, x.shape, x.stride(), block_shape)
    gluon_desc = convert_host_descriptor(in_desc)
    out_desc = convert_host_descriptor(TensorDescriptor(y, y.shape, y.stride(), block_shape))
    kernel[grid](gluon_desc, out_desc, M, N)

    y_ref = torch.zeros((M, N), device="cuda", dtype=torch.float16)
    desc_ref = TensorDescriptor(y_ref, y_ref.shape, y_ref.stride(), block_shape)
    descriptor_copy_kernel[grid](in_desc, desc_ref, M, N)
    torch.testing.assert_close(y, y_ref)


@triton.jit
def reshape_trans_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)

    x = tl.reshape(tl.load(x_ptr + offsets), 16, 16)
    y = tl.load(y_ptr + offsets).reshape(16, 16)
    a = x + y.trans(1, 0)
    a = a.reshape(256)
    tl.store(out_ptr + offsets, a)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_triton_reshape_trans(tmp_path):
    kernel = convert_kernel(reshape_trans_kernel, "reshape_trans_kernel", tmp_path)

    n = 1024
    BLOCK = 256
    x = torch.randn(n, device="cuda", dtype=torch.float32)
    y = torch.randn(n, device="cuda", dtype=torch.float32)
    out = torch.empty_like(x)
    grid = (n // BLOCK, )
    kernel[grid](x, y, out, n, BLOCK)
    ref = torch.empty_like(x)
    reshape_trans_kernel[grid](x, y, ref, n, BLOCK)
    torch.testing.assert_close(out, ref)


BLOCK_SPLIT = tl.constexpr(256)


@triton.jit
def split_kernel(x_ptr, out_ptr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SPLIT + tl.arange(0, BLOCK_SPLIT)
    offsets2 = pid * BLOCK_SPLIT + tl.arange(0, 2 * BLOCK_SPLIT)

    s0, s1 = tl.reshape(tl.load(x_ptr + offsets2), BLOCK_SPLIT, 2).split()
    a = s0 + s1
    p = out_ptr + offsets
    tl.store(p, a)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_split(tmp_path):
    kernel = convert_kernel(split_kernel, "split_kernel", tmp_path)

    n = 1024
    x = torch.randn(2 * n, device="cuda", dtype=torch.float32)
    grid = (n // BLOCK_SPLIT, )

    out = torch.empty_like(x[:n])
    kernel[grid](x, out)
    ref = torch.empty_like(x[:n])
    split_kernel[grid](x, ref)
    torch.testing.assert_close(out, ref)


@triton.jit
def reduce_to_scalar_kernel(out_ptr):
    x = tl.arange(0, 16)
    x = tl.sum(x)
    tl.store(out_ptr, x)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_reduce_to_scalar(tmp_path):
    kernel = convert_kernel(reduce_to_scalar_kernel, "reduce_to_scalar_kernel", tmp_path)
    grid = (1, )

    out = torch.empty((1, ), device="cuda", dtype=torch.int32)
    kernel[grid](out)
    ref = torch.empty_like(out)
    reduce_to_scalar_kernel[grid](ref)
    torch.testing.assert_close(out, ref)
