
Commit d4fd276

[Bugfix][Attention] Fix FlashInfer MLA block size logic (#24692)
Signed-off-by: Matthew Bonanni <[email protected]>
1 parent 7a70a71 commit d4fd276

File tree

1 file changed: +11 −2 lines changed


vllm/platforms/cuda.py

Lines changed: 11 additions & 2 deletions
@@ -146,6 +146,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             # required block_size.
             use_flashmla = False
             use_cutlass_mla = False
+            use_flashinfer_mla = False

             if envs.VLLM_ATTENTION_BACKEND is None:
                 # Default case
@@ -164,6 +165,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                 use_cutlass_mla = (
                     envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
+                use_flashinfer_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")

             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
@@ -176,6 +179,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
                             "CUTLASS_MLA backend.")
+            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashInferMLA "
+                    "backend.")

         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
@@ -228,8 +236,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                 selected_backend is None and cls.is_device_capability(100)
                 and block_size == 128)
-            use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
-                                 and cls.has_device_capability(100))
+            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
+                selected_backend is None and cls.is_device_capability(100)
+                and block_size in [32, 64])
             use_flashmla = selected_backend in [
                 _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
             ] or (selected_backend is None and is_flashmla_supported()[0])
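
For reference, below is a minimal, self-contained Python sketch (not the vLLM code itself) of the behavior this patch introduces: the KV-cache block size is coerced to a value the chosen MLA backend supports, and FlashInfer MLA becomes a default candidate on Blackwell (compute capability 10.0) when the block size is already 32 or 64. The names CacheConfig, force_mla_block_size, pick_mla_backend, and the is_blackwell flag are illustrative stand-ins for the real objects in vllm/platforms/cuda.py, and the backend precedence shown is only inferred from the order the flags are computed in the diff.

# Minimal standalone sketch (not the actual vLLM code): CacheConfig,
# force_mla_block_size, pick_mla_backend, and is_blackwell are illustrative
# stand-ins for the real objects in vllm/platforms/cuda.py.
from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheConfig:
    block_size: int


def force_mla_block_size(cache_config: CacheConfig, backend: str) -> None:
    """Force a KV-cache block size the chosen MLA backend supports."""
    if backend == "CUTLASS_MLA" and cache_config.block_size != 128:
        cache_config.block_size = 128
    elif backend == "FLASHMLA" and cache_config.block_size != 64:
        cache_config.block_size = 64
    elif backend == "FLASHINFER_MLA" and cache_config.block_size not in (32, 64):
        # New in this commit: FlashInfer MLA accepts 32 or 64; default to 64.
        cache_config.block_size = 64


def pick_mla_backend(selected: Optional[str], is_blackwell: bool,
                     block_size: int) -> str:
    """Mirror the updated default-selection logic in get_attn_backend_cls."""
    if selected == "CUTLASS_MLA" or (selected is None and is_blackwell
                                     and block_size == 128):
        return "CUTLASS_MLA"
    # New in this commit: FlashInfer MLA is also a default candidate on
    # Blackwell (capability 10.0) when the block size is already 32 or 64.
    if selected == "FLASHINFER_MLA" or (selected is None and is_blackwell
                                        and block_size in (32, 64)):
        return "FLASHINFER_MLA"
    return "FLASHMLA"


if __name__ == "__main__":
    cfg = CacheConfig(block_size=16)
    force_mla_block_size(cfg, "FLASHINFER_MLA")
    print(cfg.block_size)                     # 64
    print(pick_mla_backend(None, True, 64))   # FLASHINFER_MLA
    print(pick_mla_backend(None, True, 128))  # CUTLASS_MLA

The net effect of the commit, as the diff shows, is that FlashInfer MLA now gets the same block-size handling as the other MLA backends (an unsupported size is forced to 64, as FlashMLA forces 64 and CUTLASS MLA forces 128), and the backend can be chosen by default rather than only when explicitly requested.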

0 commit comments
