Skip to content

Commit 8f87c59

Browse files
author
l30072083
committed
[kernel] Reduce recompilation triggered by Triton function parameter specialization
Signed-off-by: l30072083 <liuchengzhuo1@h-partners.com>
1 parent 9d1452c commit 8f87c59

File tree

5 files changed

+24
-42
lines changed

5 files changed

+24
-42
lines changed

vllm_ascend/ops/triton/fla/chunk_delta_h.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
2727
}
2828
)
29-
@triton.jit(do_not_specialize=["T"])
29+
@triton.jit(do_not_specialize=["T","H", "Hg", "K", "V"])
3030
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
3131
k,
3232
v,
@@ -40,10 +40,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
4040
chunk_offsets,
4141
h_update,
4242
T,
43-
H: tl.constexpr,
44-
Hg: tl.constexpr,
45-
K: tl.constexpr,
46-
V: tl.constexpr,
43+
H,
44+
Hg,
45+
K,
46+
V,
4747
BT: tl.constexpr,
4848
USE_G: tl.constexpr,
4949
USE_INITIAL_STATE: tl.constexpr,

vllm_ascend/ops/triton/fla/cumsum.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def chunk_local_cumsum_scalar_kernel(
2626
cu_seqlens,
2727
chunk_indices,
2828
T,
29-
B: tl.constexpr,
3029
H: tl.constexpr,
3130
BLOCK_T: tl.constexpr,
3231
REVERSE: tl.constexpr,
@@ -101,7 +100,6 @@ def chunk_local_cumsum_scalar(
101100
cu_seqlens=cu_seqlens,
102101
chunk_indices=block_indices,
103102
T=T,
104-
B=B,
105103
H=H,
106104
BLOCK_T=OPTIM_BLOCK_SIZE,
107105
CHUNK_SIZE=chunk_size,

vllm_ascend/ops/triton/fused_gdn_gating.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
UNIFIED_BUFFER_SIZE = 1572864
1111

1212

13-
@triton.jit
13+
@triton.jit(do_not_specialize=["seq_len", "NUM_HEADS", "NUM_BATCHES", "beta",
14+
"threshold", "ROW_ITER"])
1415
def fused_gdn_gating_kernel(
1516
g,
1617
beta_output,
@@ -19,16 +20,17 @@ def fused_gdn_gating_kernel(
1920
b,
2021
dt_bias,
2122
seq_len,
22-
NUM_HEADS: tl.constexpr,
23-
NUM_BATCHES: tl.constexpr,
24-
beta: tl.constexpr,
25-
threshold: tl.constexpr,
23+
NUM_HEADS,
24+
NUM_BATCHES,
25+
beta,
26+
threshold,
2627
BLK_HEADS: tl.constexpr,
27-
COL_ITER: tl.constexpr,
2828
BLK_BATCHES: tl.constexpr,
29-
ROW_ITER: tl.constexpr,
29+
ROW_ITER,
3030
):
3131
i_b, i_s = tl.program_id(0), tl.program_id(1)
32+
COL_ITER = tl.cdiv(NUM_HEADS, BLK_HEADS)
33+
3234
for row_idx in range(0, ROW_ITER):
3335
batch_off = i_b * ROW_ITER * BLK_BATCHES + row_idx * BLK_BATCHES + tl.arange(0, BLK_BATCHES)
3436

@@ -69,23 +71,11 @@ def fused_gdn_gating_patch(
6971
num_cores = get_vectorcore_num()
7072

7173
BLK_HEADS = 8
72-
COL_ITER = triton.cdiv(num_heads, BLK_HEADS)
73-
74-
elem_size = a.element_size()
75-
max_ub_batches = int((UNIFIED_BUFFER_SIZE * 0.95) / (BLK_HEADS * elem_size))
76-
if batch <= num_cores:
77-
progs = batch
78-
BLK_BATCHES = 1
79-
ROW_ITER = 1
80-
else:
81-
progs = num_cores
82-
FACTOR = 8 * num_heads
83-
calc_blk_batches = (
84-
triton.next_power_of_2(triton.cdiv(int(UNIFIED_BUFFER_SIZE * 0.95), FACTOR * BLK_HEADS * elem_size)) // 2
85-
)
86-
BLK_BATCHES = max(1, min(calc_blk_batches, max_ub_batches, 64))
87-
row_per_core = triton.cdiv(batch, progs)
88-
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
74+
75+
progs = num_cores
76+
row_per_core = triton.cdiv(batch, progs)
77+
BLK_BATCHES = 64
78+
ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
8979

9080
g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
9181
beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
@@ -104,7 +94,6 @@ def fused_gdn_gating_patch(
10494
beta,
10595
threshold,
10696
BLK_HEADS=BLK_HEADS,
107-
COL_ITER=COL_ITER,
10897
BLK_BATCHES=BLK_BATCHES,
10998
ROW_ITER=ROW_ITER,
11099
)

vllm_ascend/ops/triton/reject_sample.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,8 @@
2121

2222

2323
def cal_grid_and_block_size(batch_size: int):
24-
vectorcore_num = get_vectorcore_num()
25-
if batch_size <= vectorcore_num:
26-
grid = batch_size
27-
block_size = 1
28-
else:
29-
grid = vectorcore_num
30-
block_size = triton.next_power_of_2(triton.cdiv(batch_size, grid))
24+
grid = batch_size
25+
block_size = 64
3126
return grid, block_size
3227

3328

@@ -82,7 +77,7 @@ def bonus_renew(
8277
tl.store(output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, bonus_token_id)
8378

8479

85-
@triton.jit(do_not_specialize=["max_spec_len"])
80+
@triton.jit(do_not_specialize=["vec_len", "max_spec_len"])
8681
def rejection_greedy_sample_triton(
8782
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
8883
cu_num_draft_tokens_ptr, # [batch_size]
@@ -196,7 +191,7 @@ def rejection_random_sample_kernel(
196191
)
197192

198193

199-
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
194+
@triton.jit(do_not_specialize=["replace_from", "replace_to", "vec_len"])
200195
def expand_kernel(
201196
output_ptr, # [num_tokens]
202197
input_ptr, # [batch_size]

vllm_ascend/ops/triton/spec_decode/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from vllm.triton_utils import tl, triton
1919

2020

21-
@triton.jit
21+
@triton.jit(do_not_specialize=["num_reqs"])
2222
def prepare_inputs_padded_kernel(
2323
cu_num_draft_tokens_ptr, # [num_reqs]
2424
valid_sampled_tokens_count_ptr, # [num_reqs]

0 commit comments

Comments (0)