
[BUG] Loop unswitching pass causes buggy codegen #1804

@Rachmanino

Description


Required prerequisites

What version of TileLang are you using?

0.1.7.post3+cuda.gitc47d6df3

System information

import sys, tilelang, torch
print(sys.version, sys.platform)
3.12.0 | packaged by Anaconda, Inc. | (main, Oct 2 2023, 17:29:18) [GCC 11.2.0] linux
print(tilelang.__version__)
0.1.7.post3+cuda.gitc47d6df3
print(torch.__version__)
2.10.0+cu128

The machine is a B200 provided by @Hamerlate.

Problem description

The LoopUnswitching pass seems to produce incorrect codegen; disabling the pass restores correct results. See the explanation below for details.
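For context, loop unswitching is the classic transformation that hoists a loop-invariant conditional out of a loop and duplicates the loop body, one copy per branch. A minimal CUDA-flavored sketch of the transformation itself (illustrative names, not TileLang's IR):

// Before: the branch condition does not depend on the loop variable,
// so the pass may hoist it.
__device__ void before(float* out, const float* in, int n, bool flag) {
    for (int i = 0; i < n; ++i) {
        if (flag) {                  // loop-invariant condition
            out[i] = in[i] * 2.0f;
        } else {
            out[i] = in[i];
        }
    }
}

// After unswitching: one specialized loop per branch, no per-iteration test.
__device__ void after(float* out, const float* in, int n, bool flag) {
    if (flag) {
        for (int i = 0; i < n; ++i) out[i] = in[i] * 2.0f;
    } else {
        for (int i = 0; i < n; ++i) out[i] = in[i];
    }
}

This is sound for pure loop bodies. In the kernel below, however, the hoisted condition is thread-dependent (threadIdx.x == 128) and the loop body contains collective operations (partial barriers and shared-memory stores feeding a TMA store), so duplicating the body per branch silently drops synchronization; see the generated code in the Traceback section.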

Reproducible example code

Note that this is reported from PR #1774.

import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
from tilelang.profiler import do_bench
tilelang.disable_cache()

@tilelang.jit
def gemm(
    A,
    B,
    block_M,
    block_N,
    store_block_N,  # block_N for C_shared
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    use_tma_store=True,
):
    M, N, K = T.const("M, N, K")

    A: T.Tensor[[M, K], in_dtype]
    B: T.Tensor[[K, N], in_dtype]
    C = T.empty((M, N), out_dtype)

    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
    assert K % (2 * block_K) == 0  # for simplicity
    k_blocks = T.ceildiv(K, block_K)
    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
    group_size = 8

    with T.Kernel(sm_num, threads=256) as (block_id):
        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
        B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
        C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
        C_shared = T.alloc_shared((block_M, store_block_N), out_dtype)
        loaded = T.alloc_barrier([32] * num_stages)
        consumed = T.alloc_barrier([1] * num_stages)
        tmem_full = T.alloc_barrier([1] * 2)
        tmem_empty = T.alloc_barrier([128] * 2)

        tx = T.get_thread_binding()

        if tx < 32:  # warp 0: issue tma
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_blocks):
                        T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
                        T.copy(
                            A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :]
                        )  # cannot use BufferLoad here
                        T.copy(B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared[k % num_stages, :, :])
                        T.mbarrier_arrive(loaded[k % num_stages])

        elif tx < 64:  # warp 1: issue tcgen5
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                if bx * block_M < M and by * block_N < N:
                    T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1)
                    for k in T.serial(k_blocks):
                        T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1)
                        if w & 1 == 0:
                            T.gemm(
                                A_shared[k % num_stages, :, :],
                                B_shared[k % num_stages, :, :],
                                C_tmem_0,
                                False,
                                False,
                                mbar=consumed[k % num_stages],
                                wg_wait=-1,
                                clear_accum=k == 0,
                            )
                        else:
                            T.gemm(
                                A_shared[k % num_stages, :, :],
                                B_shared[k % num_stages, :, :],
                                C_tmem_1,
                                False,
                                False,
                                mbar=consumed[k % num_stages],
                                wg_wait=-1,
                                clear_accum=k == 0,
                            )
                    T.tcgen05_mma_arrive(tmem_full[w & 1])

        elif 128 <= tx < 256:  # warp 4~7: epilogue
            for w in T.unroll(waves):
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                if bx * block_M < M and by * block_N < N:
                    T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1)
                    T.sync_threads(1, 128)
                    if (w & 1) == 0:
                        T.copy(C_tmem_0, C_local)
                    else:
                        T.copy(C_tmem_1, C_local)
                    T.mbarrier_arrive(tmem_empty[w & 1])

                    if use_tma_store:
                        for i in T.unroll(T.ceildiv(block_N, store_block_N)):
                            T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared)
                            T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N])
                    else:
                        T.copy(C_local, C_local_cast)
                        T.copy(C_local_cast, C[bx * block_M, by * block_N])
    return C


def main():
    M, N, K = 8192, 8192, 8192
    block_M, block_N, block_K = 128, 256, 64
    store_block_N = 128
    in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float
    num_stages = 4

    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    print(gemm.get_kernel_source(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages))
    c = gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)

    ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("All checks passed. ✅")

    tl_latency = do_bench(
        lambda: gemm(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti"
    )
    torch_latency = do_bench(lambda: a @ b, backend="cupti")
    print(f"Tilelang latency: {tl_latency} ms")
    print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS")
    print(f"Torch latency: {torch_latency} ms")
    print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS")


if __name__ == "__main__":
    main()

To get the codegen to run at all, you need to modify the TileLang source as shown below and run on an SM100 machine:

diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py
index 32bef1cb..90debeb1 100644
--- a/tilelang/tileop/gemm/gemm_tcgen05.py
+++ b/tilelang/tileop/gemm/gemm_tcgen05.py
@@ -96,7 +96,8 @@ class GemmTCGEN5(GemmBase):
         if mbar == 0:
             raise ValueError("TCGEN5MMA requires a valid mbarrier")
 
-        mbarptr = mbar.access_ptr("rw")
+        from tilelang.utils.language import retrieve_ptr
+        mbarptr = retrieve_ptr(mbar, "rw")
 
         C_coords = self.C_coords
         if len(C_coords) != 2:

Also, it is expected that the generated CUDA code does not pass nvcc compilation. (See #1774 for more info.)

Traceback

The generated CUDA code ends with the following:


    } else {
      if (128 <= ((int)threadIdx.x)) {
        #pragma unroll
        for (int w_2 = 0; w_2 < 14; ++w_2) {
          if (((w_2 * 37) + (((int)blockIdx.x) >> 2)) < 512) {
            tmem_full[(w_2 & 1)].wait(((w_2 >> 1) & 1));
            tl::__sync_thread_partial<1, 128>();
            if ((w_2 & 1) == 0) {
              tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_0[0], 0, (&(C_local[0])));
            } else {
              tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_1[0], 0, (&(C_local[0])));
            }
            tmem_empty[(w_2 & 1)].arrive();
            tl::__sync_thread_partial<3, 128>();
            if (((int)threadIdx.x) == 128) {
              #pragma unroll
              for (int i_1 = 0; i_1 < 2; ++i_1) {
                #pragma unroll
                for (int i_2 = 0; i_2 < 16; ++i_2) {
                  for (int vec = 0; vec < 2; ++vec) {
                    uint2 __1;
                    float4 v_ = *(float4*)(C_local + (((i_1 * 128) + (i_2 * 8)) + (vec * 4)));
                    (reinterpret_cast<__nv_bfloat162*>(&__1))[0] = __float22bfloat162_rn(((float2*)(&v_))[0]);
                    (reinterpret_cast<__nv_bfloat162*>(&__1))[1] = __float22bfloat162_rn(((float2*)(&v_))[1]);
                    *(uint2*)(C_shared_local_cast + (vec * 4)) = __1;
                  }
                  *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((i_2 >> 3) * 8192) + (((int)threadIdx.x) * 64)) + (((((i_2 & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((i_2 & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + ((((i_2 & 1) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 90112)) = *(uint4*)(C_shared_local_cast + 0);
                }
                tl::fence_proxy_async();
                #pragma unroll
                for (int i_3 = 0; i_3 < 2; ++i_3) {
                  tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[((i_3 * 8192) + 98304)])), (((((((w_2 * 37) + (((int)blockIdx.x) >> 2)) >> 7) * 2048) + ((((w_2 * 148) + ((int)blockIdx.x)) & 7) * 256)) + (i_1 * 128)) + (i_3 * 64)), (((((w_2 * 37) + (((int)blockIdx.x) >> 2)) & 127) >> 1) * 128));
                  tl::tma_store_arrive();
                  tl::tma_store_wait<0>();
                }
              }
            } else {
              #pragma unroll
              for (int i_4 = 0; i_4 < 2; ++i_4) {
                #pragma unroll
                for (int i_5 = 0; i_5 < 16; ++i_5) {
                  for (int vec_1 = 0; vec_1 < 2; ++vec_1) {
                    uint2 __2;
                    float4 v__1 = *(float4*)(C_local + (((i_4 * 128) + (i_5 * 8)) + (vec_1 * 4)));
                    (reinterpret_cast<__nv_bfloat162*>(&__2))[0] = __float22bfloat162_rn(((float2*)(&v__1))[0]);
                    (reinterpret_cast<__nv_bfloat162*>(&__2))[1] = __float22bfloat162_rn(((float2*)(&v__1))[1]);
                    *(uint2*)(C_shared_local_cast + (vec_1 * 4)) = __2;
                  }
                  *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((i_5 >> 3) * 8192) + (((int)threadIdx.x) * 64)) + (((((i_5 & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((i_5 & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + ((((i_5 & 1) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 90112)) = *(uint4*)(C_shared_local_cast + 0);
                }
              }
            }
          }
        }
      }
    }
  }
  if ((((int)threadIdx.x) >> 5) == 0) {
    tl::tmem_deallocate((&(C_tmem_1[0])), 256);
    tl::tmem_deallocate((&(C_tmem_0[0])), 256);
  }
}


This is both strange and incorrect: the pass transformed {sts + tma store} into {if (tx == 128) {sts + tma store} else {sts}} and dropped the necessary thread synchronization.
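To make the breakage concrete, here is a hand-written CUDA sketch of the ordering the epilogue relies on, assuming a 256-thread block. smem_tile and the serial copy loop are illustrative stand-ins for buf_dyn_shmem and tl::tma_store, and __syncthreads() stands in for tl::__sync_thread_partial<3, 128>; this sketches the failure mode, not the actual generated code:

// Correct shape (matching the codegen under "Expected behavior"): all
// threads produce the tile, a barrier orders the writes, then one thread
// consumes the tile via a bulk copy.
__global__ void epilogue_correct(const float* frag, float* gmem_out) {
    __shared__ float smem_tile[256];
    smem_tile[threadIdx.x] = frag[threadIdx.x];  // sts: each thread writes its slice
    __syncthreads();                             // writes visible before the copy
    if (threadIdx.x == 128) {
        for (int i = 0; i < 256; ++i)            // stand-in for the TMA bulk copy
            gmem_out[i] = smem_tile[i];
    }
    __syncthreads();                             // tile consumed before reuse
}

// Shape after the buggy unswitching: the thread-dependent branch is hoisted
// around the loop body, and the barriers between the sts and the bulk copy
// vanish (a barrier cannot legally live inside a divergent branch), so
// thread 128 may read the tile before the other threads' stores complete.
__global__ void epilogue_unswitched(const float* frag, float* gmem_out) {
    __shared__ float smem_tile[256];
    if (threadIdx.x == 128) {
        smem_tile[threadIdx.x] = frag[threadIdx.x];
        for (int i = 0; i < 256; ++i)            // may read stale/partial data
            gmem_out[i] = smem_tile[i];
    } else {
        smem_tile[threadIdx.x] = frag[threadIdx.x];
    }
}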

Expected behavior

If we disable the LoopUnswitching pass, the generated CUDA is correct:

    } else {
      if (128 <= ((int)threadIdx.x)) {
        #pragma unroll
        for (int w_2 = 0; w_2 < 14; ++w_2) {
          if (((w_2 * 37) + (((int)blockIdx.x) >> 2)) < 512) {
            tmem_full[(w_2 & 1)].wait(((w_2 >> 1) & 1));
            tl::__sync_thread_partial<1, 128>();
            if ((w_2 & 1) == 0) {
              tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_0[0], 0, (&(C_local[0])));
            } else {
              tl::tcgen05_ld_32dp32bNx<256, false>(C_tmem_1[0], 0, (&(C_local[0])));
            }
            tmem_empty[(w_2 & 1)].arrive();
            #pragma unroll
            for (int i_1 = 0; i_1 < 2; ++i_1) {
              tl::__sync_thread_partial<3, 128>();
              #pragma unroll
              for (int i_2 = 0; i_2 < 16; ++i_2) {
                for (int vec = 0; vec < 2; ++vec) {
                  uint2 __1;
                  float4 v_ = *(float4*)(C_local + (((i_1 * 128) + (i_2 * 8)) + (vec * 4)));
                  (reinterpret_cast<__nv_bfloat162*>(&__1))[0] = __float22bfloat162_rn(((float2*)(&v_))[0]);
                  (reinterpret_cast<__nv_bfloat162*>(&__1))[1] = __float22bfloat162_rn(((float2*)(&v_))[1]);
                  *(uint2*)(C_shared_local_cast + (vec * 4)) = __1;
                }
                *(uint4*)(((bfloat16_t*)buf_dyn_shmem) + (((((((i_2 >> 3) * 8192) + (((int)threadIdx.x) * 64)) + (((((i_2 & 7) >> 2) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 32)) + (((((i_2 & 3) >> 1) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + ((((i_2 & 1) + (((int)threadIdx.x) & 1)) & 1) * 8)) + 90112)) = *(uint4*)(C_shared_local_cast + 0);
              }
              tl::fence_proxy_async();
              tl::__sync_thread_partial<3, 128>();
              if (((int)threadIdx.x) == 128) {
                #pragma unroll
                for (int i_3 = 0; i_3 < 2; ++i_3) {
                  tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[((i_3 * 8192) + 98304)])), (((((((w_2 * 37) + (((int)blockIdx.x) >> 2)) >> 7) * 2048) + ((((w_2 * 148) + ((int)blockIdx.x)) & 7) * 256)) + (i_1 * 128)) + (i_3 * 64)), (((((w_2 * 37) + (((int)blockIdx.x) >> 2)) & 127) >> 1) * 128));
                  tl::tma_store_arrive();
                  tl::tma_store_wait<0>();
                }
              }
            }
          }
        }
      }
    }
  }
  if ((((int)threadIdx.x) >> 5) == 0) {
    tl::tmem_deallocate((&(C_tmem_1[0])), 256);
    tl::tmem_deallocate((&(C_tmem_0[0])), 256);
  }
}

Additional context

No response
