@@ -146,6 +146,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
146
146
# required block_size.
147
147
use_flashmla = False
148
148
use_cutlass_mla = False
149
+ use_flashinfer_mla = False
149
150
150
151
if envs .VLLM_ATTENTION_BACKEND is None :
151
152
# Default case
@@ -164,6 +165,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
164
165
use_flashmla = (envs .VLLM_ATTENTION_BACKEND == "FLASHMLA" )
165
166
use_cutlass_mla = (
166
167
envs .VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" )
168
+ use_flashinfer_mla = (
169
+ envs .VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" )
167
170
168
171
from vllm .attention .ops .flashmla import is_flashmla_supported
169
172
if use_flashmla and is_flashmla_supported ()[0 ] \
@@ -176,6 +179,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
176
179
cache_config .block_size = 128
177
180
logger .info ("Forcing kv cache block size to 128 for "
178
181
"CUTLASS_MLA backend." )
182
+ if use_flashinfer_mla and cache_config .block_size not in [32 , 64 ]:
183
+ cache_config .block_size = 64
184
+ logger .info (
185
+ "Forcing kv cache block size to 64 for FlashInferMLA "
186
+ "backend." )
179
187
180
188
# lazy import to avoid circular import
181
189
from vllm .config import CUDAGraphMode
@@ -228,8 +236,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
228
236
use_cutlassmla = selected_backend == _Backend .CUTLASS_MLA or (
229
237
selected_backend is None and cls .is_device_capability (100 )
230
238
and block_size == 128 )
231
- use_flashinfermla = (selected_backend == _Backend .FLASHINFER_MLA
232
- and cls .has_device_capability (100 ))
239
+ use_flashinfermla = selected_backend == _Backend .FLASHINFER_MLA or (
240
+ selected_backend is None and cls .is_device_capability (100 )
241
+ and block_size in [32 , 64 ])
233
242
use_flashmla = selected_backend in [
234
243
_Backend .FLASHMLA , _Backend .FLASHMLA_VLLM_V1
235
244
] or (selected_backend is None and is_flashmla_supported ()[0 ])
0 commit comments