Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions vllm_ascend/ops/triton/fla/chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
-@triton.jit(do_not_specialize=["T"])
+@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
v,
Expand All @@ -40,10 +40,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
chunk_offsets,
h_update,
T,
-H: tl.constexpr,
-Hg: tl.constexpr,
-K: tl.constexpr,
-V: tl.constexpr,
+H,
+Hg,
+K,
+V,
BT: tl.constexpr,
USE_G: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
Expand Down
2 changes: 0 additions & 2 deletions vllm_ascend/ops/triton/fla/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def chunk_local_cumsum_scalar_kernel(
cu_seqlens,
chunk_indices,
T,
-B: tl.constexpr,
H: tl.constexpr,
BLOCK_T: tl.constexpr,
REVERSE: tl.constexpr,
Expand Down Expand Up @@ -103,7 +102,6 @@ def chunk_local_cumsum_scalar(
cu_seqlens=cu_seqlens,
chunk_indices=block_indices,
T=T,
-B=B,
H=H,
BLOCK_T=OPTIM_BLOCK_SIZE,
CHUNK_SIZE=chunk_size,
Expand Down
38 changes: 13 additions & 25 deletions vllm_ascend/ops/triton/fused_gdn_gating.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
UNIFIED_BUFFER_SIZE = 1572864


-@triton.jit
+@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta", "threshold", "ROW_ITER"])
def fused_gdn_gating_kernel(
g,
beta_output,
Expand All @@ -19,16 +19,17 @@ def fused_gdn_gating_kernel(
b,
dt_bias,
seq_len,
-NUM_HEADS: tl.constexpr,
-NUM_BATCHES: tl.constexpr,
-beta: tl.constexpr,
-threshold: tl.constexpr,
+NUM_HEADS,
+NUM_BATCHES,
+beta,
+threshold,
BLK_HEADS: tl.constexpr,
-COL_ITER: tl.constexpr,
BLK_BATCHES: tl.constexpr,
-ROW_ITER: tl.constexpr,
+ROW_ITER,
):
i_b, i_s = tl.program_id(0), tl.program_id(1)
+COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS)
+
for row_idx in range(0, ROW_ITER):
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)

Expand Down Expand Up @@ -69,23 +70,11 @@ def fused_gdn_gating_patch(
num_cores = get_vectorcore_num()

BLK_HEADS = 8
-COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
-
-elem_size = a.element_size()
-max_ub_batches = int((UNIFIED_BUFFER_SIZE * 0.95) / (BLK_HEADS * elem_size))
-if batch <= num_cores:
-progs = batch
-BLK_BATCHES = 1
-ROW_ITER = 1
-else:
-progs = num_cores
-FACTOR = 8 * num_heads
-calc_blk_batches = (
-triton.next_power_of_2(triton.cdiv(int(UNIFIED_BUFFER_SIZE * 0.95), FACTOR * BLK_HEADS * elem_size)) // 2
-)
-BLK_BATCHES = max(1, min(calc_blk_batches, max_ub_batches, 64))
-row_per_core = triton.cdiv(batch, progs)
-ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
+
+progs = num_cores
+row_per_core = triton.cdiv(batch, progs)
+BLK_BATCHES = 64
+ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)

g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
Expand All @@ -104,7 +93,6 @@ def fused_gdn_gating_patch(
beta,
threshold,
BLK_HEADS=BLK_HEADS,
-COL_ITER=COL_ITER,
BLK_BATCHES=BLK_BATCHES,
ROW_ITER=ROW_ITER,
)
Expand Down
4 changes: 2 additions & 2 deletions vllm_ascend/ops/triton/reject_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def bonus_renew(
tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id)


-@triton.jit(do_not_specialize=["max_spec_len"])
+@triton.jit(do_not_specialize=["vec_len", "max_spec_len"])
def rejection_greedy_sample_triton(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
Expand Down Expand Up @@ -196,7 +196,7 @@ def rejection_random_sample_kernel(
)


-@triton.jit(do_not_specialize=["replace_from", "replace_to"])
+@triton.jit(do_not_specialize=["replace_from", "replace_to", "vec_len"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
Expand Down
2 changes: 1 addition & 1 deletion vllm_ascend/ops/triton/spec_decode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from vllm.triton_utils import tl, triton


-@triton.jit
+@triton.jit(do_not_specialize=["num_reqs"])
def prepare_inputs_padded_kernel(
cu_num_draft_tokens_ptr, # [num_reqs]
valid_sampled_tokens_count_ptr, # [num_reqs]
Expand Down
Loading