Skip to content
Open
2 changes: 1 addition & 1 deletion vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"USE_G": lambda args: args["g_cumsum"] is not None,
}
)
@triton.jit(do_not_specialize=["T"])
@triton.jit(do_not_specialize=["T", "B"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
beta, # [H, B, T]
Expand Down
8 changes: 4 additions & 4 deletions vllm_ascend/ops/triton/fla/sigmoid_gating.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def div_normal(x, y):
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
}
)
@triton.jit(do_not_specialize=["N", "T"])
@triton.jit(do_not_specialize=["scale", "N", "T", "B"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
q,
k,
Expand All @@ -53,9 +53,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
ssm_state_indices,
num_accepted_tokens,
scale,
N: tl.constexpr, # num of sequences
T: tl.constexpr, # num of tokens
B: tl.constexpr,
N, # num of sequences
T, # num of tokens
B,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
Expand Down
12 changes: 6 additions & 6 deletions vllm_ascend/ops/triton/fla/solve_tril.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
@triton.jit(do_not_specialize=["T", "H"])
def solve_tril_16x16_kernel(
A,
Ad,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
H,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
LARGE_BLOCK_T: tl.constexpr,
Expand Down Expand Up @@ -134,15 +134,15 @@ def solve_tril_16x16_kernel(


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
@triton.jit(do_not_specialize=["T", "H"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ad,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
H,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
Expand Down Expand Up @@ -198,15 +198,15 @@ def merge_16x16_to_32x32_inverse_kernel(


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
@triton.jit(do_not_specialize=["T", "H"])
def merge_16x16_to_64x64_inverse_kernel(
A,
Ad,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
H,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def split_qkv_rmsnorm_mrope_kernel(
q_size: tl.constexpr,
kv_size: tl.constexpr,
eps: tl.constexpr,
mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
mrope_section_t,
mrope_section_h,
mrope_section_w,
Comment on lines +51 to +53
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To align with this PR's goal of reducing kernel recompilations, mrope_section_t, mrope_section_h, and mrope_section_w should be added to the do_not_specialize list in the @triton.jit decorator.

You have correctly removed tl.constexpr from these parameters, but without also listing them in do_not_specialize, Triton will still specialize on their values and may recompile the kernel whenever they change.

Please update the decorator on line 28 to include them:

@triton.jit(
    do_not_specialize=[
        "num_tokens", "front_core_num", "num_tokens_each_front_core",
        "num_tokens_each_tail_core", "mrope_section_t", "mrope_section_h",
        "mrope_section_w"
    ]
)

has_bias: tl.constexpr,
is_interleaved: tl.constexpr,
rope_dim: tl.constexpr,
Expand Down
33 changes: 23 additions & 10 deletions vllm_ascend/ops/triton/mamba/causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,20 @@ def extract_last_width(x, start_loc, width):
return x[:, indices].permute(1, 0, 2)


@triton.jit
@triton.jit(
do_not_specialize=[
"batch",
"state_len",
"num_cache_lines",
"stride_x_seq",
"stride_x_token",
"stride_conv_state_seq",
"stride_conv_state_tok",
"stride_state_indices",
"stride_o_seq",
"stride_o_token",
]
)
def _causal_conv1d_update_kernel_npu_tiled(
# Pointers
x_ptr, # (batch, dim, seqlen) OR (num_tokens, dim) for varlen
Expand All @@ -173,21 +186,21 @@ def _causal_conv1d_update_kernel_npu_tiled(
batch: tl.int32,
dim: tl.constexpr,
seqlen: tl.constexpr, # max seqlen for varlen, or exact seqlen
state_len: tl.constexpr, # effective state_len computed in wrapper
num_cache_lines: tl.constexpr,
state_len, # effective state_len computed in wrapper
num_cache_lines,
# Strides
stride_x_seq: tl.constexpr,
stride_x_seq,
stride_x_dim: tl.constexpr,
stride_x_token: tl.constexpr,
stride_x_token,
stride_w_dim: tl.constexpr,
stride_w_width: tl.constexpr,
stride_conv_state_seq: tl.constexpr,
stride_conv_state_seq,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_conv_state_tok,
stride_state_indices,
stride_o_seq,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
stride_o_token,
# others
pad_slot_id: tl.constexpr,
# Meta
Expand Down
Loading