
Commit f77df94

[Perf] Add decode full-graph support to FlashInfer-MLA backend (#26313)

Signed-off-by: Benjamin Chislett <[email protected]>

1 parent f231e5b · commit f77df94
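
For context, a minimal usage sketch of how one might exercise this change. This is not part of the commit: the backend name string (assumed from get_name()), the model choice, and the cudagraph_mode compilation-config key are all assumptions based on common vLLM conventions.

# Hypothetical usage sketch, not part of this commit.
import os

# Force the FlashInfer-MLA attention backend (name string is an assumption).
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",           # an MLA-based model (assumption)
    compilation_config={"cudagraph_mode": "FULL"},  # request full-graph capture (assumption)
)
print(llm.generate("Hello")[0].outputs[0].text)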

File tree

1 file changed: +12 −1 lines changed

vllm/v1/attention/backends/mla/flashinfer_mla.py

Lines changed: 12 additions & 1 deletion
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
@@ -12,13 +12,20 @@
     MLACommonBackend,
     MLACommonImpl,
     MLACommonMetadata,
+    MLACommonMetadataBuilder,
 )
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
 FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
 
 
+class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable full CUDA Graph support for decode-only capture
+    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+
+
 class FlashInferMLABackend(MLACommonBackend):
     @staticmethod
     def get_name() -> str:
@@ -28,6 +35,10 @@ def get_name() -> str:
     def get_impl_cls() -> type["FlashInferMLAImpl"]:
         return FlashInferMLAImpl
 
+    @staticmethod
+    def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
+        return FlashInferMLAMetadataBuilder
+
 
 g_fi_workspace = torch.zeros(
     FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
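
To illustrate the mechanism the diff relies on: a self-contained sketch of the class-level capability-flag pattern, assuming the engine consults cudagraph_support on the builder class before capturing a decode graph. This is not vLLM's actual scheduler code; enum members other than UNIFORM_BATCH and the engine-side check are assumptions.

# Illustrative sketch, not vLLM's implementation.
from enum import Enum, auto
from typing import ClassVar


class AttentionCGSupport(Enum):
    NEVER = auto()          # no CUDA graph capture of attention (assumption)
    UNIFORM_BATCH = auto()  # full-graph capture for uniform decode-only batches
    ALWAYS = auto()         # full-graph capture for any batch shape (assumption)


class MLACommonMetadataBuilder:
    # Conservative base-class default: assume no full-graph support.
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder):
    # The commit overrides the class var so the engine knows decode-only
    # batches can be captured in one full CUDA graph.
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH


def can_full_graph_decode(builder_cls: type[MLACommonMetadataBuilder]) -> bool:
    """Hypothetical engine-side check before capturing a decode graph."""
    return builder_cls.cudagraph_support in (
        AttentionCGSupport.UNIFORM_BATCH,
        AttentionCGSupport.ALWAYS,
    )


print(can_full_graph_decode(FlashInferMLAMetadataBuilder))  # -> True

Because the flag is a ClassVar, the engine can make the capture decision from the builder class alone, before any builder instance (or its metadata) exists.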

0 commit comments