Commit 21a696b

[None][feat] Optimize the q3n decode kernel with IO read (NVIDIA#11344)
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
1 parent 959306c commit 21a696b

1 file changed (+1, -1)

tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ def fused_sigmoid_gating_delta_rule_update(
     B, T, H, K, V = *k.shape, v.shape[-1]
     HV = v.shape[2]
     N = B if cu_seqlens is None else len(cu_seqlens) - 1
-    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
     assert NK == 1, "NK > 1 is not supported yet"
     num_stages = 3
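
The one-line change raises the cap on BV from 8 to 32, which in turn changes NV on the following line. The sketch below works through that arithmetic for one illustrative shape. It is not part of the commit: `next_power_of_2` and `cdiv` are pure-Python stand-ins for the Triton helpers of the same name (so the snippet runs without Triton or a GPU), `block_config` is a hypothetical helper, and K = V = 128 is an assumed head size chosen only for illustration.

```python
def next_power_of_2(n: int) -> int:
    """Stand-in for triton.next_power_of_2: smallest power of two >= n (n >= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def cdiv(a: int, b: int) -> int:
    """Stand-in for triton.cdiv: ceiling division for positive integers."""
    return -(-a // b)


def block_config(K: int, V: int, bv_cap: int) -> dict:
    """Hypothetical helper mirroring the BK/BV/NK/NV lines in the diff."""
    BK = next_power_of_2(K)
    BV = min(next_power_of_2(V), bv_cap)
    return {"BK": BK, "BV": BV, "NK": cdiv(K, BK), "NV": cdiv(V, BV)}


# Assumed shape for illustration only; the commit does not state K or V.
before = block_config(K=128, V=128, bv_cap=8)   # old cap
after = block_config(K=128, V=128, bv_cap=32)   # new cap

print("before:", before)  # {'BK': 128, 'BV': 8, 'NK': 1, 'NV': 16}
print("after: ", after)   # {'BK': 128, 'BV': 32, 'NK': 1, 'NV': 4}
```

Under these assumptions each block along V covers 32 values instead of 8, so NV drops from 16 to 4, while NK stays at 1 and the existing assertion still holds. For V of 8 or less, the two caps give the same BV and NV.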
