Replies: 1 comment
I wrote a test script as follows:

import torch
import triton
import triton.language as tl


@triton.jit
def gemm(a, b, o, m: tl.constexpr, n: tl.constexpr, k: tl.constexpr, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    start_m = tl.program_id(1)
    off_a = pid * m * k
    a_ptr = tl.make_block_ptr(
        base=a + off_a,
        shape=(m, k),
        offsets=(start_m * BLOCK_M, 0),
        strides=(k, 1),
        block_shape=(BLOCK_M, k),
        order=(1, 0),
    )
    off_b = pid * k * n
    b_ptr = tl.make_block_ptr(
        base=b + off_b,
        shape=(k, n),
        offsets=(0, 0),
        strides=(1, k),
        block_shape=(k, BLOCK_M),
        order=(0, 1),
    )
    off_o = pid * m * n
    o_ptr = tl.make_block_ptr(
        base=o + off_o,
        shape=(m, n),
        offsets=(start_m * BLOCK_M, 0),
        strides=(n, 1),
        block_shape=(BLOCK_M, BLOCK_M),
        order=(1, 0),
    )
    a_tensor = tl.load(a_ptr)
    for _ in range(tl.cdiv(n, BLOCK_M)):
        b_tensor = tl.load(b_ptr)
        o_tensor = tl.dot(a_tensor, b_tensor)
        tl.store(o_ptr, o_tensor.to(tl.float16))
        b_ptr = tl.advance(b_ptr, (0, BLOCK_M))
        o_ptr = tl.advance(o_ptr, (0, BLOCK_M))


if __name__ == "__main__":
    a = torch.randn((5, 1024, 32), dtype=torch.float16, device="cuda")
    b = torch.randn((5, 1024, 32), dtype=torch.float16, device="cuda")
    o = torch.zeros((5, 1024, 1024), dtype=torch.float16, device="cuda")
    grid = lambda META: (5, triton.cdiv(1024, META["BLOCK_M"]), 1)
    gemm[grid](a, b, o, 1024, 1024, 32, 32)
    o_pt = torch.matmul(a, b.transpose(1, 2))
    torch.testing.assert_close(o, o_pt, atol=5e-3, rtol=5e-3)

Then I ran it with:

MLIR_ENABLE_DUMP=1 python3 test.py 2>tmp.log

and got the MLIR output: tmp.log

Am I right that choosing order=(0, 1) or order=(1, 0) has no effect on the final kernel launch?
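One way to check this more mechanically is to produce two dumps, one per `order` value, and diff them. Below is a minimal sketch using only Python's standard `difflib`. The helper names `changed_lines` and `diff_files` and the two file names in the usage note are hypothetical, and the sketch assumes the two dumps are otherwise produced the same way (same Triton version, same inputs).

```python
import difflib
from pathlib import Path


def changed_lines(text_a: str, text_b: str) -> list[str]:
    """Return only the removed ('-') and added ('+') lines between two texts."""
    diff = list(
        difflib.unified_diff(
            text_a.splitlines(), text_b.splitlines(), lineterm="", n=0
        )
    )
    body = diff[2:]  # drop the '---' / '+++' file header lines
    # Hunk markers start with '@@', so they are filtered out here.
    return [line for line in body if line.startswith(("+", "-"))]


def diff_files(path_a: str, path_b: str) -> list[str]:
    """Read two dump files (hypothetical paths) and return their changed lines."""
    return changed_lines(Path(path_a).read_text(), Path(path_b).read_text())
```

For example, after saving one dump as `order_01.log` and one as `order_10.log` (hypothetical names), `diff_files("order_01.log", "order_10.log")` lists every line that differs. If the only differences are in the early IR where the block pointers are created, and the later stages are identical, that would support the guess above for this particular kernel and compiler version. It would not prove it in general.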
As asked in the title: I don't understand how the order parameter of make_block_ptr works.
I have found the lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp and lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp passes, but they don't seem to process the order attribute. Where does the Triton compiler use the order parameter of make_block_ptr?
Looking forward to someone's help.
Thanks.
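For context, here is a small pure-Python sketch of the addressing arithmetic that `shape`, `strides`, `offsets` and `block_shape` imply for a 2D block. It is not Triton's implementation: `block_offsets` and `gather` are hypothetical helpers, boundary checks are ignored, and `order` is deliberately left out, since whether and where the compiler consults it is exactly the open question here.

```python
def block_offsets(strides, offsets, block_shape):
    """Linear element offsets (relative to base) covered by a 2D block.

    Assumption: element (i, j) of the block lives at
    (offsets[0] + i) * strides[0] + (offsets[1] + j) * strides[1].
    """
    (s0, s1), (o0, o1), (b0, b1) = strides, offsets, block_shape
    return [
        [(o0 + i) * s0 + (o1 + j) * s1 for j in range(b1)]
        for i in range(b0)
    ]


def gather(flat, offs):
    """Read the elements of a flat buffer at the given 2D grid of offsets."""
    return [[flat[p] for p in row] for row in offs]


# A row-major (n, k) buffer, viewed as (k, n) with strides (1, k),
# which is how a transposed operand can be described with block pointers.
n, k = 4, 2
flat = list(range(n * k))
rows = [flat[r * k:(r + 1) * k] for r in range(n)]      # the (n, k) matrix
transposed = [[rows[j][i] for j in range(n)] for i in range(k)]

view = gather(flat, block_offsets(strides=(1, k), offsets=(0, 0), block_shape=(k, n)))
assert view == transposed
```

In this model the elements a block touches are fully determined by the strides and offsets, so a transposed view needs nothing beyond `strides=(1, k)`. That is consistent with `order` being extra layout information, not part of the addressing, but the sketch cannot confirm how the compiler treats it.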
BTW, document