Skip to content

Commit f3c1b3c

Browse files
Handle max_query_len zero during cudagraph capture
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
1 parent 5a06276 commit f3c1b3c

File tree

1 file changed: 6 additions, 1 deletion

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -229,12 +229,17 @@ def _forward_decode(
229229
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
230230
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]
231231

232+
# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
233+
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
234+
# to prevent invalid grid configuration during graph capture.
235+
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
236+
232237
o = flash_attn_varlen_func(
233238
q=q_pe,
234239
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
235240
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
236241
q_v=q_nope,
237-
max_seqlen_q=attn_metadata.decode.max_query_len,
242+
max_seqlen_q=max_seqlen_q,
238243
cu_seqlens_q=attn_metadata.decode.query_start_loc,
239244
max_seqlen_k=attn_metadata.decode.max_seq_len,
240245
seqused_k=attn_metadata.decode.seq_lens,

0 commit comments

Comments (0)