Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker that this hasn't already been reported. (comment there if it has.)
What version of TileLang are you using?
0.1.7.post2
System information
('3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0]', 'linux')
0.1.7.post2
2.9.0+cu128
H100 GPUs
Problem description
The paged decode example in examples/flash_decoding/example_gqa_decode_varlen_logits_paged.py handles a specific case where the blocks of a particular sequence are contiguous. Hence, I implemented a more general case where there is no such constraint and a block can be present anywhere in the cache, compliant with the vLLM API. The reference code is given below. Certain calculations in the key and value loading code are repeated, so I tried to store them in T.var to avoid the repetition. However, that either causes an error or makes the pipelined CUDA code weird. Let me explain both:
1. Currently I have three T.var in the code:
final_scale = T.alloc_var(accum_dtype)
cur_kv_head = T.alloc_var("int32")
pages_per_blockN = T.alloc_var("int32")
The first one is fine and is created automatically even if I don't explicitly define it as a T.var.
The other two are problematic. With them, and with the number of pipeline stages set to 1 or 2, the key and value loading loops in the generated CUDA code shift into the consumer block that does the computation. Without them, the loads correctly stay in the producer block alongside the query loading code. Note that even though the generated code is weird, it passes the correctness checks and is as fast as the vLLM version of paged FlashAttention-3 when num_stages is 1, and only slightly slower (<1-2 ms) when num_stages is 2. Still, correct code with the key and value logic placed in the producer block might give a better time.
2. There is a fourth T.var which is commented out:
block_table_offset = T.alloc_var("int32")
If uncommented, it leads to a compilation error. The traceback is given below.
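For reference, the scalar arithmetic these T.var are meant to cache is trivial. Here is a plain-Python sketch (the function name and packaging are mine, purely for illustration) of the chunk-split and page-offset calculation that the kernel performs per (bid, sid) program:

```python
def split_offsets(total_chunks, num_split, sid, block_N, block_size):
    # How many block_N-sized chunks split `sid` handles and where its pages
    # start in the block table. Mirrors the kernel's base_chunks_per_split /
    # remainder_chunks / prev_split_chunks / block_table_offset math.
    base = total_chunks // num_split
    rem = total_chunks % num_split
    final_chunks = base + (1 if sid < rem else 0)
    prev_split_chunks = base * sid + min(sid, rem)
    pages_per_blockN = -(-block_N // block_size)  # ceildiv(block_N, block_size)
    # block_table offset of the k-th chunk processed by this split:
    offsets = [(prev_split_chunks + k) * pages_per_blockN for k in range(final_chunks)]
    return final_chunks, prev_split_chunks, offsets
```

For example, 10 chunks split 4 ways gives the splits 3, 3, 2, 2 chunks, starting at chunks 0, 3, 6, 8 respectively, so together they cover every chunk exactly once.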
Reproducible example code
The Python snippets:
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from tilelang.engine.param import KernelParam
import argparse
import itertools
from tvm import tir, ir
from ..kernel_utils import *
torch.random.manual_seed(0)
def get_configs():
block_N = [64, 128]
block_H = [64]
num_split = [2, 4, 8]
num_stages = [1, 2, 3]
threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs
def get_pass_config():
return {"tl.disable_safe_memory_legalize": True}
@tilelang.jit(out_idx=[], execution_backend="tvm_ffi", pass_configs=get_pass_config())
def flashattn(heads, groups, max_blocks, num_blocks, block_size, dim, batch, block_N, block_H, num_split, num_stages, threads):
log_scale = 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // groups
shape_q = [batch, heads, dim]
shape_k = [num_blocks, block_size, groups, dim]
shape_v = [num_blocks, block_size, groups, dim]
shape_o = [batch, heads, dim]
shape_bt = [batch, max_blocks]
part_shape = [batch, heads, num_split, dim]
valid_block_H = min(block_H, kv_group_num)
assert block_size <= block_N, "block_size must be less than or equal to block_N"
@T.macro
def flash_attn_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_table: T.Tensor(shape_bt, dtype="int32"),
cache_seqlen: T.Tensor([batch], dtype="int32"),
softmax_scale: T.float32,
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
):
with T.Kernel(
batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([valid_block_H, dim], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
final_scale = T.alloc_var(accum_dtype)
cur_kv_head = T.alloc_var("int32")
pages_per_blockN = T.alloc_var("int32")
# block_table_offset = T.alloc_var("int32")
bid = bx
hid = by
sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H)
final_scale = log_scale * softmax_scale
pages_per_blockN = T.ceildiv(block_N, block_size)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
total_chunks = T.ceildiv(cache_seqlen[bid], block_N)
base_chunks_per_split = T.floordiv(total_chunks, num_split)
remainder_chunks = T.floormod(total_chunks, num_split)
final_chunks = base_chunks_per_split + T.if_then_else(sid < remainder_chunks, 1, 0)
prev_split_chunks = base_chunks_per_split * sid + T.min(sid, remainder_chunks)
start_idx = prev_split_chunks * block_N
for k in T.Pipelined(final_chunks, num_stages=num_stages):
block_table_offset = (prev_split_chunks + k) * pages_per_blockN
for i,j in T.Parallel(block_N, dim):
page_idx = T.floordiv(i, block_size)
token_idx = start_idx + k * block_N + i
K_shared[i,j] = T.if_then_else(token_idx < cache_seqlen[bid], K[block_table[bid, block_table_offset + page_idx], T.floormod(i, block_size), cur_kv_head, j], 0)
T.clear(acc_s)
T.gemm(
Q_shared,
K_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(start_idx + k * block_N + j < cache_seqlen[bid], acc_s[i, j],
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * final_scale - scores_max[i] * final_scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * final_scale - scores_max[i] * final_scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
for i,j in T.Parallel(block_N, dim):
page_idx = T.floordiv(i, block_size)
token_idx = start_idx + k * block_N + i
V_shared[i,j] = T.if_then_else(token_idx < cache_seqlen[bid], V[block_table[bid, block_table_offset + page_idx], T.floormod(i, block_size), cur_kv_head, j], 0)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if final_chunks > 0:
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * final_scale
T.copy(logsum[:valid_block_H],
glse[bid, hid * valid_block_H:(hid + 1) * valid_block_H, sid])
T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H,
sid, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local = T.alloc_fragment([num_split, 128], dtype)
lse_logsum_local = T.alloc_fragment([128], accum_dtype)
lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_fragment([128], accum_dtype)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
for k, j in T.Parallel(num_split, 128):
lse_local[k, j] = glse[bz, by, k]
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
for k in T.serial(num_split):
for j in T.Parallel(128):
lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j])
for j in T.Parallel(128):
lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j]
for k in T.serial(num_split):
for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i]
for j in T.Parallel(128):
scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
# Note: Pay attention to dim and the number of threads in Parallel
for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[i]
for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def flashattn_gqa_decode_split(
Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype),
block_table: T.Tensor(shape_bt, dtype="int32"),
cache_seqlen: T.Tensor([batch], dtype="int32"),
softmax_scale: T.float32,
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype),
):
flash_attn_split(Q, K, V, block_table, cache_seqlen, softmax_scale, glse, Output_partial)
combine(glse, Output_partial, Output)
return flashattn_gqa_decode_split
class FlashAttnGQADecodeKVCache(torch.nn.Module):
def __init__(self, heads: int, groups: int, dim: int, max_blocks: int, num_blocks: int, block_size: int):
super().__init__()
self.heads = heads
self.groups = groups
self.dim = dim
self.max_blocks = max_blocks
self.num_blocks = num_blocks
self.block_size = block_size
self.block_N = 128
self.block_H = 64
self.num_split = 8
self.kernel = flashattn(
heads=heads,
groups=groups,
dim=dim,
max_blocks=max_blocks,
num_blocks=T.symbolic("num_blocks"),
block_size=block_size,
batch=T.symbolic("batch"),
block_N=self.block_N,
block_H=self.block_H,
num_split=self.num_split,
num_stages=1,
threads=128)
def forward(self, query, key, value, block_table, cache_seqlen, softmax_scale):
batch_size = query.size(0)
glse = torch.empty(batch_size, self.heads, 8, device=query.device, dtype=query.dtype)
output_partial = torch.empty(batch_size, self.heads, 8, self.dim, device=query.device, dtype=query.dtype)
output = torch.empty(batch_size, self.heads, self.dim, device=query.device, dtype=query.dtype)
self.kernel(query, key, value, block_table, cache_seqlen, softmax_scale, glse, output_partial, output)
return output
def ref_program_fa(query, key, value, block_table, cache_seqlens, softmax_scale, max_seqlen_k, cu_seqlens_q):
from vllm.vllm_flash_attn import flash_attn_varlen_func
out = flash_attn_varlen_func(query, key, value, 1, cu_seqlens_q, max_seqlen_k, seqused_k=cache_seqlens, causal=False, softmax_scale=softmax_scale, block_table=block_table)
return out
def main(max_batch: int = 1,
heads: int = 32,
groups: int = 8,
max_cache_seqlen: int = 8192,
block_size: int = 32,
max_model_len: int = 131072,
dim: int = 128,
run_type: str = "benchmark" # "benchmark", "correctness"
):
qk_flops = 2 * max_batch * heads * max_cache_seqlen * dim
pv_flops = 2 * max_batch * heads * max_cache_seqlen * dim
total_flops = qk_flops + pv_flops
num_blocks = ((max_model_len * max_batch) + block_size-1) // block_size
max_blocks = (max_model_len + block_size-1) // block_size
kernel = FlashAttnGQADecodeKVCache(heads, groups, dim, max_blocks, num_blocks, block_size)
if (run_type == "correctness"):
num_runs = 100
else:
num_runs = 10
lat_f_sum = lat_og_sum = 0.0
cnt_f = cnt_og = 0
for _ in range(num_runs):
if run_type == "correctness":
batch = torch.randint(1, max_batch + 1, (1,)).item()
cache_seqlens = torch.randint(128, max_cache_seqlen, (batch,), device='cuda', dtype=torch.int32)
else:
batch = max_batch
cache_seqlens = torch.full((batch,), max_cache_seqlen, device='cuda', dtype=torch.int32)
q = torch.randn(batch, heads, dim, device="cuda", dtype=torch.float16)
k = torch.randn(num_blocks, block_size, groups, dim, device="cuda", dtype=torch.float16)
v = torch.randn(num_blocks, block_size, groups, dim, device="cuda", dtype=torch.float16)
# Create block_table where each request gets random unique block ids
block_table = torch.randperm(batch * max_blocks, device="cuda", dtype=torch.int32).reshape(batch, max_blocks)
max_seqlen_k = cache_seqlens.max().item()
cu_seqlens_q = torch.zeros(batch + 1, device='cuda', dtype=torch.int32)
cu_seqlens_q[1:] = cache_seqlens.cumsum(dim=0)
if run_type == "correctness":
o_ref = ref_program_fa(q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5, max_seqlen_k, cu_seqlens_q)
o_og = kernel(q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5)
assert_similar(o_ref, o_og, name="ref_og", assert_=True, print_=True)
if run_type == "benchmark":
try:
latency_f = do_bench(ref_program_fa, n_warmup=1, n_repeat=5, input_tensors=[q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5, max_seqlen_k, cu_seqlens_q])
lat_f_sum += latency_f
cnt_f += 1
except Exception:
pass
try:
latency_og = do_bench(kernel, n_warmup=1, n_repeat=5, input_tensors=[q, k, v, block_table, cache_seqlens, (1.0 / dim)**0.5])
lat_og_sum += latency_og
cnt_og += 1
except Exception:
pass
del q, k, v, cache_seqlens
torch.cuda.empty_cache()
if run_type == "benchmark":
avg_f = lat_f_sum / cnt_f if cnt_f > 0 else 0.0
avg_og = lat_og_sum / cnt_og if cnt_og > 0 else 0.0
results = [f"{max_cache_seqlen}"]
tflops_f = total_flops / avg_f * 1e-9 if avg_f > 0 else 0
tflops_og = total_flops / avg_og * 1e-9 if avg_og > 0 else 0
results.extend([f"{avg_f:.2f}", f"{tflops_f:.2f}"])
results.extend([f"{avg_og:.2f}", f"{tflops_og:.2f}"])
print(",".join(results))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--max_batch', type=int, default=4, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--groups', type=int, default=8, help='groups')
parser.add_argument('--max_cache_seqlen', type=int, default=8192, help='max cache sequence length')
parser.add_argument('--block_size', type=int, default=32, help='block size <= 128')
parser.add_argument('--max_model_len', type=int, default=131072, help='max model length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--run_type', type=str, default="benchmark", choices=["benchmark", "correctness"], help='Type of run: benchmark, correctness')
args = parser.parse_args()
main(args.max_batch, args.heads, args.groups, args.max_cache_seqlen, args.block_size, args.max_model_len, args.dim, args.run_type)
Traceback
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/dd/workspace/src/kernels/flash_decoding/gqa_decode_paged.py", line 322, in <module>
main(args.max_batch, args.heads, args.groups, args.max_cache_seqlen, args.block_size, args.max_model_len, args.dim, args.run_type)
File "/home/dd/workspace/src/kernels/flash_decoding/gqa_decode_paged.py", line 251, in main
kernel = FlashAttnGQADecodeKVCache(heads, groups, dim, max_blocks, num_blocks, block_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dd/workspace/src/kernels/flash_decoding/gqa_decode_paged.py", line 209, in __init__
self.kernel = flashattn(
^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/__init__.py", line 414, in __call__
kernel = self.compile(*args, **kwargs, **tune_params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/__init__.py", line 349, in compile
kernel_result = compile(
^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/__init__.py", line 98, in compile
return cached(
^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/cache/__init__.py", line 74, in cached
return _dispatch_map[execution_backend].cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/cache/kernel_cache.py", line 204, in cached
kernel = JITKernel(
^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/kernel.py", line 137, in __init__
adapter = self._compile_and_create_adapter(func, out_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/jit/kernel.py", line 242, in _compile_and_create_adapter
artifact = tilelang.lower(
^^^^^^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/engine/lower.py", line 270, in lower
mod = OptimizeForTarget(mod, target)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/engine/phase.py", line 205, in OptimizeForTarget
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dd/miniconda3/envs/kascade_vllm/lib/python3.12/site-packages/tilelang/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::Inject(tvm::tir::PrimFunc const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#15}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockRealizeNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#14}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::BlockNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#15}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockRealizeNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#14}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::BlockNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#3}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::IfThenElseNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#14}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::BlockNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::BlockNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#1}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::LetStmtNode const*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)
File "<unknown>", line 0, in tvm::ffi::ObjectPtr<tvm::ffi::Object> tvm::ffi::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::ffi::ObjectPtr<tvm::ffi::Object>, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::ffi::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#4}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
File "<unknown>", line 0, in tvm::tl::software_pipeline::PipelineInjector::VisitStmt_(tvm::tir::ForNode const*)
File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
ValueError: Check failed: src_info.stage <= dst_info.stage (1 vs. 0) : statement with T.block("", no_realize=True):
batch = T.int32()
cache_seqlen = T.Buffer((batch,), "int32", strides=(1,))
bid = T.int32()
num_blocks = T.int32()
K = T.Buffer((num_blocks, 64, 8, 128), "float16", strides=(65536, 1024, 128, 1))
tx = T.int32()
cur_kv_head = T.Buffer((1,), "int32", scope="local.var")
block_table = T.Buffer((batch, 2048), "int32", strides=(2048, 1))
block_table_offset = T.Buffer((1,), "int32", scope="local.var")
T.reads(cache_seqlen[bid], K[0:num_blocks, tx // 16:tx // 16 + 57, cur_kv_head[0], tx % 16 * 8:tx % 16 * 8 + 8], block_table[bid, block_table_offset[0]:block_table_offset[0] + 2], block_table_offset[0], cur_kv_head[0])
K_shared = T.Buffer((1, 2, 16, 512), "float16", scope="shared.dyn")
k = T.int32()
T.writes(K_shared[T.FloorMod(k, 1), tx % 16 // 8, 0:16, tx // 16 * 64 + (tx // 64 + tx % 8 // 4) % 2 * 32 + (tx % 64 // 32 + tx % 4 // 2) % 2 * 16 + (tx % 32 // 16 + tx % 2) % 2 * 8:tx // 16 * 64 + (tx // 64 + tx % 8 // 4) % 2 * 32 + (tx % 64 // 32 + tx % 4 // 2) % 2 * 16 + (tx % 32 // 16 + tx % 2) % 2 * 8 + 8])
for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(8):
sid = T.int32()
total_chunks = T.int32()
K_shared[T.FloorMod(k, 1), tx % 16 // 8, i, tx // 16 * 64 + (tx // 64 + tx % 8 // 4) % 2 * 32 + (tx % 64 // 32 + tx % 4 // 2) % 2 * 16 + (tx % 32 // 16 + tx % 2) % 2 * 8 + vec] = T.if_then_else(T.min(sid, total_chunks % 8) * 128 + k * 128 + total_chunks // 8 * sid * 128 + i * 8 + tx // 16 < cache_seqlen[bid], K[block_table[bid, i // 8 + block_table_offset[0]], i % 8 * 8 + tx // 16, cur_kv_head[0], tx % 16 * 8 + vec], T.float16(0.0)) in stage 0 cannot depends on statement with T.block("", no_realize=True):
pages_per_blockN = T.Buffer((1,), "int32", scope="local.var")
T.reads(pages_per_blockN[0])
block_table_offset = T.Buffer((1,), "int32", scope="local.var")
T.writes(block_table_offset[0])
sid = T.int32()
total_chunks = T.int32()
k = T.int32()
block_table_offset[0] = (T.min(sid, total_chunks % 8) + total_chunks // 8 * sid + k) * pages_per_blockN[0] in a later stage 1
Expected behavior
The T.var hold simple scalar calculations which should not affect the pipeline logic. Looking at the CUDA generated without these variables, it is easy enough to hoist these common calculations into registers by hand. So, ideally TileLang should be able to do this as well and not produce an error or weird code.
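To underline that point, the paged load is a pure function of the loop indices, so nothing about it should force a stage change. A plain-Python sketch (identifiers are mine, for illustration) of the mapping from row i of chunk chunk_idx to a KV cache location, matching K[block_table[bid, block_table_offset + page_idx], i % block_size, ...]:

```python
def paged_location(block_table_row, chunk_idx, i, block_N, block_size):
    # Map row i of the chunk_idx-th block_N tile to
    # (physical block id, in-block offset) via the block table.
    pages_per_blockN = -(-block_N // block_size)  # ceildiv
    block_table_offset = chunk_idx * pages_per_blockN
    page_idx = i // block_size
    return block_table_row[block_table_offset + page_idx], i % block_size
```

With an identity block table (block j stored at physical slot j) this reduces to the contiguous layout: the returned (block, offset) pair decodes back to the global token index chunk_idx * block_N + i.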
Additional context
No response