Skip to content

Commit d6661c0

Browse files
HarpsealCCl30072083
and authored
[v0.18.0][kernel] Recompilation optimization triggered by triton function parameter optimization (vllm-project#7647)
### What this PR does / why we need it? Some parameters of the Triton operators are unnecessarily marked with the `constexpr` qualifier. When the values of these parameters change, a recompilation is triggered, which significantly degrades model performance. This change removes the `constexpr` qualifier from those parameters (and routes them through `do_not_specialize` where appropriate) so that value changes no longer force recompilation. - vLLM version: v0.17.0 - vLLM main: vllm-project/vllm@8b63257 Signed-off-by: HarpSealCC [844291270@qq.com](mailto:844291270@qq.com) Signed-off-by: l30072083 <liuchengzhuo1@h-partners.com> Co-authored-by: l30072083 <liuchengzhuo1@h-partners.com>
1 parent d781902 commit d6661c0

File tree

5 files changed

+21
-35
lines changed

5 files changed

+21
-35
lines changed

vllm_ascend/ops/triton/fla/chunk_delta_h.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
2727
}
2828
)
29-
@triton.jit(do_not_specialize=["T"])
29+
@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
3030
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
3131
k,
3232
v,
@@ -40,10 +40,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
4040
chunk_offsets,
4141
h_update,
4242
T,
43-
H: tl.constexpr,
44-
Hg: tl.constexpr,
45-
K: tl.constexpr,
46-
V: tl.constexpr,
43+
H,
44+
Hg,
45+
K,
46+
V,
4747
BT: tl.constexpr,
4848
USE_G: tl.constexpr,
4949
USE_INITIAL_STATE: tl.constexpr,

vllm_ascend/ops/triton/fla/cumsum.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def chunk_local_cumsum_scalar_kernel(
2626
cu_seqlens,
2727
chunk_indices,
2828
T,
29-
B: tl.constexpr,
3029
H: tl.constexpr,
3130
BLOCK_T: tl.constexpr,
3231
REVERSE: tl.constexpr,
@@ -103,7 +102,6 @@ def chunk_local_cumsum_scalar(
103102
cu_seqlens=cu_seqlens,
104103
chunk_indices=block_indices,
105104
T=T,
106-
B=B,
107105
H=H,
108106
BLOCK_T=OPTIM_BLOCK_SIZE,
109107
CHUNK_SIZE=chunk_size,

vllm_ascend/ops/triton/fused_gdn_gating.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
UNIFIED_BUFFER_SIZE = 1572864
1111

1212

13-
@triton.jit
13+
@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta", "threshold", "ROW_ITER"])
1414
def fused_gdn_gating_kernel(
1515
g,
1616
beta_output,
@@ -19,16 +19,17 @@ def fused_gdn_gating_kernel(
1919
b,
2020
dt_bias,
2121
seq_len,
22-
NUM_HEADS: tl.constexpr,
23-
NUM_BATCHES: tl.constexpr,
24-
beta: tl.constexpr,
25-
threshold: tl.constexpr,
22+
NUM_HEADS,
23+
NUM_BATCHES,
24+
beta,
25+
threshold,
2626
BLK_HEADS: tl.constexpr,
27-
COL_ITER: tl.constexpr,
2827
BLK_BATCHES: tl.constexpr,
29-
ROW_ITER: tl.constexpr,
28+
ROW_ITER,
3029
):
3130
i_b, i_s = tl.program_id(0), tl.program_id(1)
31+
COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS)
32+
3233
for row_idx in range(0, ROW_ITER):
3334
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
3435

@@ -69,23 +70,11 @@ def fused_gdn_gating_patch(
6970
num_cores = get_vectorcore_num()
7071

7172
BLK_HEADS = 8
72-
COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
73-
74-
elem_size = a.element_size()
75-
max_ub_batches = int((UNIFIED_BUFFER_SIZE * 0.95) / (BLK_HEADS * elem_size))
76-
if batch <= num_cores:
77-
progs = batch
78-
BLK_BATCHES = 1
79-
ROW_ITER = 1
80-
else:
81-
progs = num_cores
82-
FACTOR = 8 * num_heads
83-
calc_blk_batches = (
84-
triton.next_power_of_2(triton.cdiv(int(UNIFIED_BUFFER_SIZE * 0.95), FACTOR * BLK_HEADS * elem_size)) // 2
85-
)
86-
BLK_BATCHES = max(1, min(calc_blk_batches, max_ub_batches, 64))
87-
row_per_core = triton.cdiv(batch, progs)
88-
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
73+
74+
progs = num_cores
75+
row_per_core = triton.cdiv(batch, progs)
76+
BLK_BATCHES = 64
77+
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
8978

9079
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
9180
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
@@ -104,7 +93,6 @@ def fused_gdn_gating_patch(
10493
beta,
10594
threshold,
10695
BLK_HEADS=BLK_HEADS,
107-
COL_ITER=COL_ITER,
10896
BLK_BATCHES=BLK_BATCHES,
10997
ROW_ITER=ROW_ITER,
11098
)

vllm_ascend/ops/triton/reject_sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def bonus_renew(
8282
tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id)
8383

8484

85-
@triton.jit(do_not_specialize=["max_spec_len"])
85+
@triton.jit(do_not_specialize=["vec_len", "max_spec_len"])
8686
def rejection_greedy_sample_triton(
8787
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
8888
cu_num_draft_tokens_ptr, # [batch_size]
@@ -196,7 +196,7 @@ def rejection_random_sample_kernel(
196196
)
197197

198198

199-
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
199+
@triton.jit(do_not_specialize=["replace_from", "replace_to", "vec_len"])
200200
def expand_kernel(
201201
output_ptr, # [num_tokens]
202202
input_ptr, # [batch_size]

vllm_ascend/ops/triton/spec_decode/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from vllm.triton_utils import tl, triton
1919

2020

21-
@triton.jit
21+
@triton.jit(do_not_specialize=["num_reqs"])
2222
def prepare_inputs_padded_kernel(
2323
cu_num_draft_tokens_ptr, # [num_reqs]
2424
valid_sampled_tokens_count_ptr, # [num_reqs]

0 commit comments

Comments (0)