Skip to content

Commit 6040e0b

Browse files
LucasWilkinsongemini-code-assist[bot]mgoin
authored andcommitted
[BugFix] Fix FI accuracy issue when used for MLA prefill (#26063)
Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: mgoin <[email protected]> Signed-off-by: simon-mo <[email protected]>
1 parent 05bf0c5 commit 6040e0b

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,13 +1211,18 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
12111211
k, v, return_softmax_lse):
12121212
assert isinstance(prefill, FlashInferPrefillMetadata)
12131213
assert prefill.prefill_main is not None
1214-
return prefill.prefill_main.run(
1214+
ret = prefill.prefill_main.run(
12151215
q=q,
12161216
k=k,
12171217
v=v,
12181218
return_lse=return_softmax_lse,
12191219
)
12201220

1221+
if isinstance(ret, tuple):
1222+
# Convert from (q_len, num_heads) to (num_heads, q_len)
1223+
return ret[0], ret[1].transpose(0, 1).contiguous()
1224+
return ret
1225+
12211226
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
12221227
q, k, v, return_softmax_lse):
12231228
assert isinstance(prefill, CudnnPrefillMetadata)
@@ -1260,12 +1265,14 @@ def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
12601265
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
12611266
chunk_idx: int, q, k, v):
12621267
assert isinstance(prefill, FlashInferPrefillMetadata)
1263-
return prefill.prefill_chunks[chunk_idx].run(
1268+
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
12641269
q=q,
12651270
k=k,
12661271
v=v,
12671272
return_lse=True,
12681273
)
1274+
# Convert from (q_len, num_heads) to (num_heads, q_len)
1275+
return attn_out, lse.transpose(0, 1).contiguous()
12691276

12701277
def _run_prefill_context_chunk_cudnn(self,
12711278
prefill: MLACommonPrefillMetadata,

0 commit comments

Comments
 (0)