Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker and this hasn't already been reported. (Comment there if it has.)
What version of TileLang are you using?
0.1.7.post3+cuda.gitc47d6df3
System information
import sys, tilelang, torch
print(sys.version, sys.platform)
3.12.0 | packaged by Anaconda, Inc. | (main, Oct 2 2023, 17:29:18) [GCC 11.2.0] linux
print(tilelang.__version__)
0.1.7.post3+cuda.gitc47d6df3
print(torch.__version__)
2.10.0+cu128
The machine is a B200 provided by @Hamerlate.
Problem description
The LoopUnswitching pass appears to cause incorrect codegen; the result returns to normal when this pass is disabled. See the explanation below for details.
Reproducible example code
Note that this is reported from PR #1774.
import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
from tilelang.profiler import do_bench
tilelang.disable_cache()
@tilelang.jit
def gemm(
    A,
    B,
    block_M,
    block_N,
    store_block_N,  # block_N for C_shared
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    use_tma_store=True,
):
    M, N, K = T.const("M, N, K")
    A: T.Tensor[[M, K], in_dtype]
    B: T.Tensor[[K, N], in_dtype]
    C = T.empty((M, N), out_dtype)
    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
    assert K % (2 * block_K) == 0  # for simplicity
    k_blocks = T.ceildiv(K, block_K)
    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
    group_size = 8
    with T.Kernel(sm_num, threads=256) as (block_id):
        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
        B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
        C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
        C_shared = T.alloc_shared((block_M, store_block_N), out_dtype)
        loaded = T.alloc_barrier([32] * num_stages)
        consumed = T.alloc_barrier([1] * num_stages)
        tmem_full = T.alloc_barrier([1] * 2)
        tmem_empty = T.alloc_barrier([128] * 2)
        tx = T.get_thread_binding()
        if tx < 32:  # warp 0: issue tma
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_blocks):
                        T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
                        T.copy(
                            A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :]
                        )  # cannot use BufferLoad here
                        T.copy(B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared[k % num_stages, :, :])
                        T.mbarrier_arrive(loaded[k % num_stages])
        elif tx < 64:  # warp 1: issue tcgen5
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
                if bx * block_M < M and by * block_N < N:
                    T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1)
                    for k in T.serial(k_blocks):
                        T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
                        if w & 1 == 0:
                            T.gemm(
                                A_shared[k % num_stages, :, :],
                                B_shared[k % num_stages, :, :],
                                C_tmem_0,
                                False,
                                False,
                                mbar=consumed[k % num_stages],
                                wg_wait=-1,
                                clear_accum=k == 0,
                            )
                        else:
                            T.gemm(
                                A_shared[k % num_stages, :, :],
                                B_shared[k % num_stages, :, :],
                                C_tmem_1,
                                False,
                                False,
                                mbar=consumed[k % num_stages],
                                wg_wait=-1,
                                clear_accum=k == 0,
                            )
                    T.tcgen05_mma_arrive(tmem_full[w & 1])
        elif 128 <= tx < 256:  # warp 4~7: epilogue
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
                if bx * block_M < M and by * block_N < N:
                    T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1)
                    T.sync_threads(1, 128)
                    if (w & 1) == 0:
                        T.copy(C_tmem_0, C_local)
                    else:
                        T.copy(C_tmem_1, C_local)
                    T.mbarrier_arrive(tmem_empty[w & 1])
                    if use_tma_store:
                        for i in T.unroll(T.ceildiv(block_N, store_block_N)):
                            T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared)
                            T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N])
                    else:
                        T.copy(C_local, C_local_cast)
                        T.copy(C_local_cast, C[bx * block_M, by * block_N])
    return C


def main():
    M, N, K = 8192, 8192, 8192
    block_M, block_N, block_K = 128, 256, 64
    store_block_N = 128
    in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
    num_stages = 4
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    print(gemm.get_kernel_source(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages))
    c = gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)
    ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All checks passed. ✅")
    tl_latency = do_bench(
        lambda: gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti"
    )
    torch_latency = do_bench(lambda: a @ b, backend="cupti")
    print(f"Tilelang latency: {tl_latency} ms")
    print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS")
    print(f"Torch latency: {torch_latency} ms")
    print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS")


if __name__ == "__main__":
    main()
To get codegen to run successfully, you need to modify the tilelang source as shown below and run on an sm100 machine:
diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py
index 32bef1cb..90debeb1 100644
--- a/tilelang/tileop/gemm/gemm_tcgen05.py
+++ b/tilelang/tileop/gemm/gemm_tcgen05.py
@@ -96,7 +96,8 @@ class GemmTCGEN5(GemmBase):
         if mbar == 0:
             raise ValueError("TCGEN5MMA requires a valid mbarrier")
-        mbarptr = mbar.access_ptr("rw")
+        from tilelang.utils.language import retrieve_ptr
+        mbarptr = retrieve_ptr(mbar, "rw")
         C_coords = self.C_coords
         if len(C_coords) != 2:
Also, it is expected that the generated CUDA code does not pass nvcc compilation. (Check out #1774 for more info.)
Traceback
The generated CUDA ends with something like this:
} else {
if (128 <= ((int)threadIdx.x)) {
#pragma unroll
for (int w_2 = 0; w_2 < 14; ++w_2) {
if (((w_2 * 37) + (((int)blockIdx.x) >> 2)) < 512) {
tmem_full[(w_2 & 1)].wait(((w_2 >> 1) & 1));
tl::__sync_thread_partial<1, 128>();
if ((w_2 & 1) == 0) {
tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_0[0], 0, (&(C_local[0])));
} else {
tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_1[0], 0, (&(C_local[0])));
}
tmem_empty[(w_2 & 1)].arrive();
tl::__sync_thread_partial<3, 128>();
if (((int)threadIdx.x) == 128) {
#pragma unroll
for (int i_1 = 0; i_1 < 2; ++i_1) {
#pragma unroll
for (int i_2 = 0; i_2 < 16; ++i_2) {
for (int vec = 0; vec < 2; ++vec) {
uint2 __1;
float4 v_ = *(float4*)(C_local + (((i_1 * 128) + (i_2 * 8)) + (vec * 4)));
(reinterpret_cast<__nv_bfloat162*>(&__1))[0] = __float22bfloat162_rn(((float2*)(&v_))[0]);
(reinterpret_cast<__nv_bfloat162*>(&__1))[1] = __float22bfloat162_rn(((float2*)(&v_))[1]);
*(uint2*)(C_shared_local_cast + (vec * 4)) = __1;
}
*(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((i_2 >> 3) * 8192) + (((int)threadIdx.x) * 64)) + (((((i_2 & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((i_2 & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + ((((i_2 & 1) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 90112)) = *(uint4*)(C_shared_local_cast + 0);
}
tl::fence_proxy_async();
#pragma unroll
for (int i_3 = 0; i_3 < 2; ++i_3) {
tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[((i_3 * 8192) + 98304)])), (((((((w_2 * 37) + (((int)blockIdx.x) >> 2)) >> 7) * 2048) + ((((w_2 * 148) + ((int)blockIdx.x)) & 7) * 256)) + (i_1 * 128)) + (i_3 * 64)), (((((w_2 * 37) + (((int)blockIdx.x) >> 2)) & 127) >> 1) * 128));
tl::tma_store_arrive();
tl::tma_store_wait<0>();
}
}
} else {
#pragma unroll
for (int i_4 = 0; i_4 < 2; ++i_4) {
#pragma unroll
for (int i_5 = 0; i_5 < 16; ++i_5) {
for (int vec_1 = 0; vec_1 < 2; ++vec_1) {
uint2 __2;
float4 v__1 = *(float4*)(C_local + (((i_4 * 128) + (i_5 * 8)) + (vec_1 * 4)));
(reinterpret_cast<__nv_bfloat162*>(&__2))[0] = __float22bfloat162_rn(((float2*)(&v__1))[0]);
(reinterpret_cast<__nv_bfloat162*>(&__2))[1] = __float22bfloat162_rn(((float2*)(&v__1))[1]);
*(uint2*)(C_shared_local_cast + (vec_1 * 4)) = __2;
}
*(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((i_5 >> 3) * 8192) + (((int)threadIdx.x) * 64)) + (((((i_5 & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((i_5 & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + ((((i_5 & 1) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 90112)) = *(uint4*)(C_shared_local_cast + 0);
}
}
}
}
}
}
}
}
if ((((int)threadIdx.x) >> 5) == 0) {
tl::tmem_deallocate((&(C_tmem_1[0])), 256);
tl::tmem_deallocate((&(C_tmem_0[0])), 256);
}
}
This is both strange and incorrect: the transform turned {sts + tma store} into {if (tx == 128) {sts + tma store} else {sts}} and dropped the thread sync that is needed between the shared-memory store and the TMA store.
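To make the failure mode concrete, below is a distilled CUDA sketch of what the transform appears to do. This is an illustration, not the actual generated kernel: the names epilogue_original, epilogue_unswitched, smem, frag, and n_iters are made up, __syncthreads() stands in for tl::__sync_thread_partial<3, 128>(), and the leader-only bulk copy stands in for tl::tma_store. The point is that the predicate being unswitched (threadIdx.x == 128) is not uniform across the 128 epilogue threads, so hoisting it out of the loop breaks the barrier pairing between the cooperative shared-memory store and the leader-only TMA store:

// Illustrative sketch only -- all names are hypothetical stand-ins.

// Original loop: every epilogue thread runs the same iterations, so the
// barrier between the cooperative shared store and the leader-only bulk
// copy is a proper group-wide rendezvous in every iteration.
__device__ void epilogue_original(float* smem, const float* frag, int n_iters) {
  for (int i = 0; i < n_iters; ++i) {
    smem[threadIdx.x] = frag[i];   // cooperative shared-memory store (the "sts")
    __syncthreads();               // all threads wait until the tile is complete
    if (threadIdx.x == 0) {
      // leader issues the bulk copy of the tile (tma_store in the real kernel)
    }
    __syncthreads();               // tile is consumed before the next iteration overwrites it
  }
}

// Unswitched on the thread-dependent predicate: the leader and the other
// threads now execute different copies of the loop, the per-iteration
// barriers are gone, and nothing orders the other threads' stores before
// the leader's bulk copy.
__device__ void epilogue_unswitched(float* smem, const float* frag, int n_iters) {
  if (threadIdx.x == 0) {
    for (int i = 0; i < n_iters; ++i) {
      smem[threadIdx.x] = frag[i];
      // bulk copy issued here without waiting for the other threads' stores
    }
  } else {
    for (int i = 0; i < n_iters; ++i) {
      smem[threadIdx.x] = frag[i];  // races with the leader's bulk copy
    }
  }
}

Unswitching is only valid when the hoisted condition is invariant for every thread that participates in the barriers inside the loop; here it is not, which matches the missing tl::__sync_thread_partial calls in the snippet above.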
Expected behavior
If we disable the LoopUnswitching pass, the generated CUDA is correct:
} else {
if (128 <= ((int)threadIdx.x)) {
#pragma unroll
for (int w_2 = 0; w_2 < 14; ++w_2) {
if (((w_2 * 37) + (((int)blockIdx.x) >> 2)) < 512) {
tmem_full[(w_2 & 1)].wait(((w_2 >> 1) & 1));
tl::__sync_thread_partial<1, 128>();
if ((w_2 & 1) == 0) {
tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_0[0], 0, (&(C_local[0])));
} else {
tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_1[0], 0, (&(C_local[0])));
}
tmem_empty[(w_2 & 1)].arrive();
#pragma unroll
for (int i_1 = 0; i_1 < 2; ++i_1) {
tl::__sync_thread_partial<3, 128>();
#pragma unroll
for (int i_2 = 0; i_2 < 16; ++i_2) {
for (int vec = 0; vec < 2; ++vec) {
uint2 __1;
float4 v_ = *(float4*)(C_local + (((i_1 * 128) + (i_2 * 8)) + (vec * 4)));
(reinterpret_cast<__nv_bfloat162*>(&__1))[0] = __float22bfloat162_rn(((float2*)(&v_))[0]);
(reinterpret_cast<__nv_bfloat162*>(&__1))[1] = __float22bfloat162_rn(((float2*)(&v_))[1]);
*(uint2*)(C_shared_local_cast + (vec * 4)) = __1;
}
*(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((i_2 >> 3) * 8192) + (((int)threadIdx.x) * 64)) + (((((i_2 & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((i_2 & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + ((((i_2 & 1) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 90112)) = *(uint4*)(C_shared_local_cast + 0);
}
tl::fence_proxy_async();
tl::__sync_thread_partial<3, 128>();
if (((int)threadIdx.x) == 128) {
#pragma unroll
for (int i_3 = 0; i_3 < 2; ++i_3) {
tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[((i_3 * 8192) + 98304)])), (((((((w_2 * 37) + (((int)blockIdx.x) >> 2)) >> 7) * 2048) + ((((w_2 * 148) + ((int)blockIdx.x)) & 7) * 256)) + (i_1 * 128)) + (i_3 * 64)), (((((w_2 * 37) + (((int)blockIdx.x) >> 2)) & 127) >> 1) * 128));
tl::tma_store_arrive();
tl::tma_store_wait<0>();
}
}
}
}
}
}
}
}
if ((((int)threadIdx.x) >> 5) == 0) {
tl::tmem_deallocate((&(C_tmem_1[0])), 256);
tl::tmem_deallocate((&(C_tmem_0[0])), 256);
}
}
Additional context
No response