
[BUG] Using T.var for index calculations leads to incorrect pipelined code and errors #1792

@Dhruv88


What version of TileLang are you using?

0.1.7.post2

System information

('3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0]', 'linux')
TileLang: 0.1.7.post2
PyTorch: 2.9.0+cu128
GPU: H100

Problem description

The paged decode example in examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py handles a specific case where the blocks of a particular sequence are contiguous. I therefore implemented the more general case, which has no such constraint: a block can be located anywhere in the KV cache, as the vLLM API allows. The code is given below. Some calculations in the key and value loading code are repeated, so I tried to store them in T.var to avoid the repetition. However, that either causes an error or makes the pipelined CUDA code look wrong. Let me explain both:

  1. Currently the code has three T.var allocations:
final_scale = T.alloc_var(accum_dtype)
cur_kv_head = T.alloc_var("int32")
pages_per_blockN = T.alloc_var("int32")

The first one is fine; an equivalent variable is created automatically even if I don't explicitly declare it as a T.var.
The other two are problematic. With them, and with the number of pipeline stages set to 1 or 2, the key and value loading loops in the generated CUDA code move into the consumer block that does the computation. Without them, those loops are correctly part of the producer block, along with the query loading code. Although the generated code looks wrong, it passes the correctness checks, and it is as fast as vLLM's paged FlashAttention 3 when num_stages is 1 and only slightly slower (<1-2 ms) when num_stages is 2. Still, correct code with the key and value loading placed in the producer block might give a better time.
  2. There is a fourth T.var, which is commented out:

block_table_offset = T.alloc_var("int32")

If it is uncommented, compilation fails inside the InjectSoftwarePipeline pass. The traceback is given below.
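The index arithmetic in both items can be separated and checked in plain Python, outside TileLang. In the sketch below (helper names are hypothetical, and a shuffled list stands in for one row of block_table), pages_per_blockN and kv_group_num // valid_block_H come only from Python ints known when flashattn is traced, cur_kv_head depends only on the block index by, and block_table_offset is the only value that changes with the pipelined loop variable k. One possible workaround, not verified against TileLang 0.1.7.post2, is therefore to compute the first two as Python ints and keep the other two as ordinary expressions rather than T.alloc_var. The sketch also compares the kernel's address formula with the direct token-to-page mapping.

```python
import random


def ceildiv(a, b):
    return -(-a // b)


def trace_time_values(heads, groups, block_H, block_N, block_size):
    """Values fixed when flashattn is traced; all inputs are Python ints."""
    kv_group_num = heads // groups
    valid_block_H = min(block_H, kv_group_num)
    pages_per_blockN = ceildiv(block_N, block_size)
    heads_per_tile = kv_group_num // valid_block_H
    return pages_per_blockN, heads_per_tile


def cur_kv_head(hid, heads_per_tile):
    # Depends only on the block index `by`: invariant across the k loop.
    return hid // heads_per_tile


def block_table_offset(prev_split_chunks, k, pages_per_blockN):
    # The only quantity that changes with the pipelined loop variable k.
    return (prev_split_chunks + k) * pages_per_blockN


def slot_kernel(row, prev_split_chunks, k, i, block_N, block_size):
    """(physical block, slot) for row i of chunk k, as flash_attn_split computes it."""
    offset = block_table_offset(prev_split_chunks, k, ceildiv(block_N, block_size))
    return row[offset + i // block_size], i % block_size


def slot_direct(row, token_idx, block_size):
    """Reference mapping: logical token -> (physical block, slot)."""
    return row[token_idx // block_size], token_idx % block_size


def mapping_matches(block_N, block_size, num_chunks=6, seed=0):
    """True if both mappings agree for every token of num_chunks chunks."""
    rng = random.Random(seed)
    row = list(range(num_chunks * ceildiv(block_N, block_size)))
    rng.shuffle(row)  # stand-in for one block_table row: pages can be anywhere
    for chunk in range(num_chunks):
        prev_split_chunks, k = 0, chunk  # only their sum matters here
        for i in range(block_N):
            token_idx = (prev_split_chunks + k) * block_N + i
            if slot_kernel(row, prev_split_chunks, k, i, block_N, block_size) != \
                    slot_direct(row, token_idx, block_size):
                return False
    return True


print(trace_time_values(32, 8, 64, 128, 32))
print(mapping_matches(128, 32), mapping_matches(96, 64))
```

With the reproducer's sizes (block_N=128, block_size=32) the two mappings agree. With block_N=96 and block_size=64 they disagree, so the kernel formula appears to require block_N to be a multiple of block_size, which `assert block_size <= block_N` alone does not enforce.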

Reproducible example code

The Python snippets:

import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from tilelang.engine.param import KernelParam 
import argparse
import itertools
from tvm import tir, ir
from ..kernel_utils import *

torch.random.manual_seed(0)


def get_configs():
    block_N = [64, 128]
    block_H = [64]
    num_split = [2, 4, 8]
    num_stages = [1, 2, 3]
    threads = [128]
    _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))

    configs = [{
        'block_N': c[0],
        'block_H': c[1],
        'num_split': c[2],
        'num_stages': c[3],
        'threads': c[4]
    } for c in _configs]
    return configs
def get_pass_config():
    return { "tl.disable_safe_memory_legalize": True}

@tilelang.jit(out_idx=[], execution_backend="tvm_ffi", pass_configs=get_pass_config())
def flashattn(heads, groups, max_blocks, num_blocks, block_size, dim, batch, block_N, block_H, num_split, num_stages, threads):
    log_scale = 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // groups
    shape_q = [batch, heads, dim]
    shape_k = [num_blocks, block_size, groups, dim]
    shape_v = [num_blocks, block_size, groups, dim]
    shape_o = [batch, heads, dim]
    shape_bt = [batch, max_blocks]
    part_shape = [batch, heads, num_split, dim]
    valid_block_H = min(block_H, kv_group_num)
    assert block_size <= block_N, "block_size must be less than or equal to block_N"

    @T.macro
    def flash_attn_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            block_table: T.Tensor(shape_bt, dtype="int32"),
            cache_seqlen: T.Tensor([batch], dtype="int32"),
            softmax_scale: T.float32,
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor(part_shape, dtype),
    ):
        with T.Kernel(
                batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([valid_block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)
            final_scale = T.alloc_var(accum_dtype)
            cur_kv_head = T.alloc_var("int32")
            pages_per_blockN = T.alloc_var("int32")
            # block_table_offset = T.alloc_var("int32")

            bid = bx
            hid = by
            sid = bz
            cur_kv_head = hid // (kv_group_num // valid_block_H)
            final_scale = log_scale * softmax_scale
            pages_per_blockN = T.ceildiv(block_N, block_size)

            T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            total_chunks = T.ceildiv(cache_seqlen[bid], block_N)
            base_chunks_per_split = T.floordiv(total_chunks, num_split)
            remainder_chunks = T.floormod(total_chunks, num_split)
            final_chunks = base_chunks_per_split + T.if_then_else(sid < remainder_chunks, 1, 0)
            prev_split_chunks = base_chunks_per_split * sid + T.min(sid, remainder_chunks)
            start_idx = prev_split_chunks * block_N
            for k in T.Pipelined(final_chunks, num_stages=num_stages):
                block_table_offset = (prev_split_chunks + k) * pages_per_blockN
                for i,j in T.Parallel(block_N, dim):
                    page_idx = T.floordiv(i, block_size)
                    token_idx = start_idx + k * block_N + i
                    K_shared[i,j] = T.if_then_else(token_idx < cache_seqlen[bid], K[block_table[bid, block_table_offset + page_idx], T.floormod(i, block_size), cur_kv_head, j], 0)
                T.clear(acc_s)
                T.gemm(
                    Q_shared,
                    K_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(start_idx + k * block_N + j < cache_seqlen[bid], acc_s[i, j],
                                                    -T.infinity(accum_dtype))
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * final_scale - scores_max[i] * final_scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * final_scale - scores_max[i] * final_scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                for i,j in T.Parallel(block_N, dim):
                    page_idx = T.floordiv(i, block_size)
                    token_idx = start_idx + k * block_N + i
                    V_shared[i,j] = T.if_then_else(token_idx < cache_seqlen[bid], V[block_table[bid, block_table_offset + page_idx], T.floormod(i, block_size), cur_kv_head, j], 0)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            if final_chunks > 0:
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * final_scale

            T.copy(logsum[:valid_block_H],
                    glse[bid, hid * valid_block_H:(hid + 1) * valid_block_H, sid])
            T.copy(acc_o[:valid_block_H, :], O_shared)
            T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H,
                                            sid, :])

    @T.macro
    def combine(
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor(part_shape, dtype),
            Output: T.Tensor(shape_o, dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local = T.alloc_fragment([num_split, 128], dtype)
            lse_logsum_local = T.alloc_fragment([128], accum_dtype)
            lse_max_local = T.alloc_fragment([128], accum_dtype)
            scale_local = T.alloc_fragment([128], accum_dtype)

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            for k, j in T.Parallel(num_split, 128):
                lse_local[k, j] = glse[bz, by, k]
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
            for k in T.serial(num_split):
                for j in T.Parallel(128):
                    lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j])
            for j in T.Parallel(128):
                lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j]
            for k in T.serial(num_split):
                for i in T.Parallel(dim):
                    po_local[i] = Output_partial[bz, by, k, i]
                for j in T.Parallel(128):
                    scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
                # Note: Pay attention to dim and the number of threads in Parallel
                for i in T.Parallel(dim):
                    o_accum_local[i] += po_local[i] * scale_local[i]
            for i in T.Parallel(dim):
                Output[bz, by, i] = o_accum_local[i]

    @T.prim_func
    def flashattn_gqa_decode_split(
            Q: T.Tensor(shape_q, dtype),
            K: T.Tensor(shape_k, dtype),
            V: T.Tensor(shape_v, dtype),
            block_table: T.Tensor(shape_bt, dtype="int32"),
            cache_seqlen: T.Tensor([batch], dtype="int32"),
            softmax_scale: T.float32,
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor(part_shape, dtype),
            Output: T.Tensor(shape_o, dtype),
    ):
        flash_attn_split(Q, K, V, block_table, cache_seqlen, softmax_scale, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return flashattn_gqa_decode_split

class FlashAttnGQADecodeKVCache(torch.nn.Module):
    def __init__(self, heads: int, groups: int, dim: int, max_blocks: int, num_blocks: int, block_size: int):
        super().__init__()
        self.heads = heads
        self.groups = groups
        self.dim = dim
        self.max_blocks = max_blocks
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.block_N = 128
        self.block_H = 64
        self.num_split = 8
        self.kernel = flashattn(
            heads=heads,
            groups=groups,
            dim=dim,
            max_blocks=max_blocks,
            num_blocks=T.symbolic("num_blocks"),       
            block_size=block_size,
            batch=T.symbolic("batch"),
            block_N=self.block_N,
            block_H=self.block_H,
            num_split=self.num_split,
            num_stages=1,
            threads=128)

    def forward(self, query, key, value, block_table, cache_seqlen, softmax_scale):
        batch_size = query.size(0)
        glse = torch.empty(batch_size, self.heads, self.num_split, device=query.device, dtype=query.dtype)
        output_partial = torch.empty(batch_size, self.heads, self.num_split, self.dim, device=query.device, dtype=query.dtype)
        output = torch.empty(batch_size, self.heads, self.dim, device=query.device, dtype=query.dtype)  
        self.kernel(query, key, value, block_table, cache_seqlen, softmax_scale, glse, output_partial, output)
        return output

def ref_program_fa(query, key, value, block_table, cache_seqlens, softmax_scale, max_seqlen_k, cu_seqlens_q):
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    out = flash_attn_varlen_func(query, key, value, 1, cu_seqlens_q, max_seqlen_k, seqused_k=cache_seqlens, causal=False, softmax_scale=softmax_scale, block_table=block_table)
    return out


def main(max_batch: int = 1,
         heads: int = 32,
         groups: int = 8,
         max_cache_seqlen: int = 8192,
         block_size: int = 32,
         max_model_len: int = 131072,
         dim: int = 128,
         run_type: str = "benchmark"  # "benchmark", "correctness"
         ):
    qk_flops = 2 * max_batch * heads * max_cache_seqlen * dim
    pv_flops = 2 * max_batch * heads * max_cache_seqlen * dim
    total_flops = qk_flops + pv_flops
    num_blocks = ((max_model_len * max_batch) + block_size-1) // block_size
    max_blocks = (max_model_len + block_size-1) // block_size
    kernel = FlashAttnGQADecodeKVCache(heads, groups, dim, max_blocks, num_blocks, block_size)

    if (run_type == "correctness"):
        num_runs = 100
    else:
        num_runs = 10

    lat_f_sum = lat_og_sum = 0.0
    cnt_f = cnt_og = 0

    for _ in range(num_runs):
        if run_type == "correctness":
            batch = torch.randint(1, max_batch + 1, (1,)).item()
            cache_seqlens = torch.randint(128, max_cache_seqlen, (batch,), device='cuda', dtype=torch.int32)
        else:
            batch = max_batch
            cache_seqlens = torch.full((batch,), max_cache_seqlen, device='cuda', dtype=torch.int32)
        q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
        k = torch.randn(num_blocks, block_size, groups, dim, device="cuda", dtype=torch.float16)
        v = torch.randn(num_blocks, block_size, groups, dim, device="cuda", dtype=torch.float16)
        # Create block_table where each request gets random unique block ids
        block_table = torch.randperm(batch * max_blocks, device="cuda", dtype=torch.int32).reshape(batch, max_blocks) 
        max_seqlen_k = cache_seqlens.max().item()
        # Decode: each request contributes exactly one query token, so the query
        # offsets are 0, 1, ..., batch (key lengths are passed via seqused_k).
        cu_seqlens_q = torch.arange(batch + 1, device='cuda', dtype=torch.int32)

        if run_type == "correctness":
            o_ref = ref_program_fa(q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5, max_seqlen_k, cu_seqlens_q)
            o_og = kernel(q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5)
            assert_similar(o_ref, o_og, name="ref_og", assert_=True, print_=True)

        if run_type == "benchmark":
            try:
                latency_f = do_bench(ref_program_fa, n_warmup=1, n_repeat=5, input_tensors=[q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5, max_seqlen_k, cu_seqlens_q])
                lat_f_sum += latency_f
                cnt_f += 1
            except Exception:
                pass
            try:
                latency_og = do_bench(kernel, n_warmup=1, n_repeat=5, input_tensors=[q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5])
                lat_og_sum += latency_og
                cnt_og += 1
            except Exception:
                pass
        del q, k, v, cache_seqlens
        torch.cuda.empty_cache()

    if run_type == "benchmark":
        avg_f = lat_f_sum / cnt_f if cnt_f > 0 else 0.0
        avg_og = lat_og_sum / cnt_og if cnt_og > 0 else 0.0

        results = [f"{max_cache_seqlen}"]
        tflops_f = total_flops / avg_f * 1e-9 if avg_f > 0 else 0
        tflops_og = total_flops / avg_og * 1e-9 if avg_og > 0 else 0
        results.extend([f"{avg_f:.2f}", f"{tflops_f:.2f}"])
        results.extend([f"{avg_og:.2f}", f"{tflops_og:.2f}"])
        print(",".join(results))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_batch', type=int, default=4, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--groups', type=int, default=8, help='groups')
    parser.add_argument('--max_cache_seqlen', type=int, default=8192, help='max cache sequence length')
    parser.add_argument('--block_size', type=int, default=32, help='block size <= 128')
    parser.add_argument('--max_model_len', type=int, default=131072, help='max model length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--run_type', type=str, default="benchmark", choices=["benchmark", "correctness"], help='Type of run: benchmark, correctness')
    args = parser.parse_args()
    main(args.max_batch, args.heads, args.groups, args.max_cache_seqlen, args.block_size, args.max_model_len, args.dim, args.run_type)
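The split-K bookkeeping in flash_attn_split (total_chunks, base_chunks_per_split, remainder_chunks, prev_split_chunks), its base-2 online softmax with final_scale = log2(e) * softmax_scale, and the log2-sum-exp2 merge in combine can be sanity-checked on the CPU. The sketch below is a simplified model, not TileLang code: it handles a single head with scalar values instead of dim-wide vectors, applies the softmax scale as a plain float, and its function names are hypothetical.

```python
import math
import random

LOG2E = 1.44269504  # same constant as log_scale in flashattn


def split_ranges(seqlen, block_N, num_split):
    """Chunk range [start, end) handled by each split, as in flash_attn_split."""
    total_chunks = -(-seqlen // block_N)
    base, rem = divmod(total_chunks, num_split)
    ranges = []
    for sid in range(num_split):
        final_chunks = base + (1 if sid < rem else 0)
        prev_split_chunks = base * sid + min(sid, rem)
        ranges.append((prev_split_chunks, prev_split_chunks + final_chunks))
    return ranges


def split_partial(scores, values, start, end, softmax_scale):
    """Base-2 online softmax over tokens [start, end); returns (o, lse2)."""
    if start >= end:
        return 0.0, -math.inf  # empty split: kernel leaves acc_o = 0, logsum = -inf
    final_scale = LOG2E * softmax_scale
    m, l, o = -math.inf, 0.0, 0.0
    for t in range(start, end):
        m_new = max(m, scores[t])
        corr = 2.0 ** (m * final_scale - m_new * final_scale)  # scores_scale
        p = 2.0 ** (scores[t] * final_scale - m_new * final_scale)
        l = l * corr + p
        o = o * corr + p * values[t]
        m = m_new
    return o / l, math.log2(l) + m * final_scale


def combine(partials):
    """Merge per-split results with a log2-sum-exp2, as the combine kernel does."""
    lse_max = max(lse for _, lse in partials)
    lse_logsum = math.log2(sum(2.0 ** (lse - lse_max) for _, lse in partials)) + lse_max
    return sum(o * 2.0 ** (lse - lse_logsum) for o, lse in partials)


def split_k_decode(scores, values, softmax_scale, block_N, num_split):
    seqlen = len(scores)
    partials = [
        split_partial(scores, values, c0 * block_N, min(c1 * block_N, seqlen), softmax_scale)
        for c0, c1 in split_ranges(seqlen, block_N, num_split)
    ]
    return combine(partials)


def reference(scores, values, softmax_scale):
    """Plain softmax-weighted average, for comparison."""
    m = max(scores)
    w = [math.exp((s - m) * softmax_scale) for s in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)


rng = random.Random(0)
scores = [rng.gauss(0.0, 4.0) for _ in range(300)]
values = [rng.gauss(0.0, 1.0) for _ in range(300)]
print(split_ranges(300, 64, 8))
print(split_k_decode(scores, values, 0.125, 64, 8), reference(scores, values, 0.125))
```

With seqlen=300, block_N=64 and num_split=8 there are only five chunks, so the last three splits receive empty ranges; as in the kernel, they report lse = -inf and contribute nothing to the merged output.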

Traceback

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/dd/workspace/src/kernels/flash_decoding/gqa_decode_paged.py", line 322, in <module>
    main(args.max_batch, args.heads, args.groups, args.max_cache_seqlen, args.block_size, args.max_model_len, args.dim, args.run_type)
  File "/home/dd/workspace/src/kernels/flash_decoding/gqa_decode_paged.py", line 251, in main
    kernel = FlashAttnGQADecodeKVCache(heads, groups, dim, max_blocks, num_blocks, block_size)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dd/workspace/src/kernels/flash_decoding/gqa_decode_paged.py", line 209, in __init__
    self.kernel = flashattn(
                  ^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/__init__.py", line 414, in __call__
    kernel = self.compile(*args, **kwargs, **tune_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/__init__.py", line 349, in compile
    kernel_result = compile(
                    ^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/__init__.py", line 98, in compile
    return cached(
           ^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/cache/__init__.py", line 74, in cached
    return _dispatch_map[execution_backend].cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/cache/kernel_cache.py", line 204, in cached
    kernel = JITKernel(
             ^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/kernel.py", line 137, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/kernel.py", line 242, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/engine/lower.py", line 270, in lower
    mod = OptimizeForTarget(mod, target)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/engine/phase.py", line 205, in OptimizeForTarget
    mod = tilelang.transform.InjectSoftwarePipeline()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::Inject(tvm::tir::PrimFunc const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#15}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockRealizeNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#14}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::BlockNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
  File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#15}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockRealizeNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#14}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::BlockNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
  File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#3}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::IfThenElseNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
  File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#14}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::BlockNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
  File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#1}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::LetStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
  File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#4}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::ForNode const*)
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
ValueError: Check failed: src_info.stage <= dst_info.stage (1 vs. 0) : statement with T.block("", no_realize=True):
    batch = T.int32()
    cache_seqlen = T.Buffer((batch,), "int32", strides=(1,))
    bid = T.int32()
    num_blocks = T.int32()
    K = T.Buffer((num_blocks, 64, 8, 128), "float16", strides=(65536, 1024, 128, 1))
    tx = T.int32()
    cur_kv_head = T.Buffer((1,), "int32", scope="local.var")
    block_table = T.Buffer((batch, 2048), "int32", strides=(2048, 1))
    block_table_offset = T.Buffer((1,), "int32", scope="local.var")
    T.reads(cache_seqlen[bid], K[0:num_blocks, tx // 16:tx // 16 + 57, cur_kv_head[0], tx % 16 * 8:tx % 16 * 8 + 8], block_table[bid, block_table_offset[0]:block_table_offset[0] + 2], block_table_offset[0], cur_kv_head[0])
    K_shared = T.Buffer((1, 2, 16, 512), "float16", scope="shared.dyn")
    k = T.int32()
    T.writes(K_shared[T.FloorMod(k, 1), tx % 16 // 8, 0:16, tx // 16 * 64 + (tx // 64 + tx % 8 // 4) % 2 * 32 + (tx % 64 // 32 + tx % 4 // 2) % 2 * 16 + (tx % 32 // 16 + tx % 2) % 2 * 8:tx // 16 * 64 + (tx // 64 + tx % 8 // 4) % 2 * 32 + (tx % 64 // 32 + tx % 4 // 2) % 2 * 16 + (tx % 32 // 16 + tx % 2) % 2 * 8 + 8])
    for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
        for vec in T.vectorized(8):
            sid = T.int32()
            total_chunks = T.int32()
            K_shared[T.FloorMod(k, 1), tx % 16 // 8, i, tx // 16 * 64 + (tx // 64 + tx % 8 // 4) % 2 * 32 + (tx % 64 // 32 + tx % 4 // 2) % 2 * 16 + (tx % 32 // 16 + tx % 2) % 2 * 8 + vec] = T.if_then_else(T.min(sid, total_chunks % 8) * 128 + k * 128 + total_chunks // 8 * sid * 128 + i * 8 + tx // 16 < cache_seqlen[bid], K[block_table[bid, i // 8 + block_table_offset[0]], i % 8 * 8 + tx // 16, cur_kv_head[0], tx % 16 * 8 + vec], T.float16(0.0)) in stage 0 cannot depends on statement with T.block("", no_realize=True):
    pages_per_blockN = T.Buffer((1,), "int32", scope="local.var")
    T.reads(pages_per_blockN[0])
    block_table_offset = T.Buffer((1,), "int32", scope="local.var")
    T.writes(block_table_offset[0])
    sid = T.int32()
    total_chunks = T.int32()
    k = T.int32()
    block_table_offset[0] = (T.min(sid, total_chunks % 8) + total_chunks // 8 * sid + k) * pages_per_blockN[0] in a later stage 1
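The check that fires here can be read as a rule about the pipelined loop body. If a statement writes a value that a later statement in the same iteration reads, the writer must not be placed in a later pipeline stage than the reader. The toy Python model below reproduces the violation. It is not TileLang's actual pass, and its statement names and read/write sets are simplified from the TIR dump above. In the dump, the `block_table_offset` assignment is in stage 1, while the `K_shared` copy that reads it is in stage 0.

```python
from dataclasses import dataclass, field


@dataclass
class Stmt:
    """One statement of the pipelined loop body (simplified)."""
    name: str
    stage: int
    reads: set = field(default_factory=set)
    writes: set = field(default_factory=set)


def stage_order_violations(body):
    """Return (src, src_stage, dst, dst_stage) for every pair where `dst`
    reads something `src` wrote earlier in the same iteration, but `src`
    was scheduled in a later stage (mirrors src.stage <= dst.stage)."""
    violations = []
    for i, src in enumerate(body):
        for dst in body[i + 1:]:
            if src.writes & dst.reads and src.stage > dst.stage:
                violations.append((src.name, src.stage, dst.name, dst.stage))
    return violations


# Stage assignment as reported in the error message.
failing_body = [
    Stmt("block_table_offset = (...) * pages_per_blockN", stage=1,
         reads={"pages_per_blockN"}, writes={"block_table_offset"}),
    Stmt("K_shared <- K[block_table[bid, i // 8 + block_table_offset], ...]",
         stage=0,
         reads={"K", "block_table", "block_table_offset", "cur_kv_head",
                "cache_seqlen"},
         writes={"K_shared"}),
]

for v in stage_order_violations(failing_body):
    print("stage {1} statement {0!r} feeds stage {3} statement {2!r}".format(*v))
```

In this model, putting the assignment in stage 0 (or recomputing the offset inline in the load) leaves no violations. This is consistent with the kernel compiling when `block_table_offset` is not stored in a `T.alloc_var`.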

Expected behavior

These `T.var`s hold only simple index calculations and should not affect the pipeline logic. Looking at the CUDA generated without them, it is easy to see how these common calculations could be hoisted into registers. Ideally, TileLang should be able to do the same, instead of raising an error or moving the key/value loads out of the producer block.
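The toy model below (hypothetical, not TileLang's scheduler) shows why recomputing the offset is safe while a single mutable variable is not. It uses a 2-deep pipeline that issues the stage-0 load for iteration `k + 1` before the stage-1 work for iteration `k`. `block_table_offset` is a pure function of the loop index `k`, using the expression copied from the stage-1 assignment in the error message. A load that recomputes it therefore always sees the right value. A load that reads one shared register instead sees a value from an earlier iteration, or none at all. The values of `sid`, `total_chunks` and `pages_per_blockN`, and the reading of the literal `8` as the number of splits, are assumptions.

```python
NUM_SPLIT = 8  # assumption: the literal 8 in the TIR dump is the split count


def block_table_offset(k, sid, total_chunks, pages_per_blockN):
    # Same expression as the stage-1 assignment in the error message.
    start = min(sid, total_chunks % NUM_SPLIT) + total_chunks // NUM_SPLIT * sid
    return (start + k) * pages_per_blockN


def two_stage_schedule(n):
    """(stage, k) issue order of a 2-deep software pipeline: the stage-0
    load for iteration k+1 is issued before the stage-1 work of iteration k."""
    order = [(0, 0)]
    for k in range(n):
        if k + 1 < n:
            order.append((0, k + 1))
        order.append((1, k))
    return order


def run(n, sid, total_chunks, pages_per_blockN, recompute):
    shared_offset = None   # models a single `local.var` register
    offsets_used = {}      # offset seen by the stage-0 load of each iteration
    for stage, k in two_stage_schedule(n):
        if stage == 0:
            if recompute:
                off = block_table_offset(k, sid, total_chunks, pages_per_blockN)
            else:
                off = shared_offset  # whatever stage 1 wrote last
            offsets_used[k] = off
        else:
            # Stage 1 updates the variable, as in the rejected schedule.
            shared_offset = block_table_offset(k, sid, total_chunks, pages_per_blockN)
    return offsets_used


expected = {k: block_table_offset(k, 3, 20, 2) for k in range(4)}
print(run(4, 3, 20, 2, recompute=True) == expected)  # True
print(run(4, 3, 20, 2, recompute=False))  # {0: None, 1: None, 2: 18, 3: 20}
```

Under these assumptions, the loads for iterations 0 and 1 get no offset and later loads get the offset of iteration `k - 2`, while recomputing from `k` matches `expected` for every iteration. This is the kind of rematerialization the expected behavior above describes.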

Additional context

No response
