File tree Expand file tree Collapse file tree 1 file changed +12
-1
lines changed
vllm/v1/attention/backends/mla Expand file tree Collapse file tree 1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change 1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
3
4
- from typing import Optional , Union
4
+ from typing import ClassVar , Optional , Union
5
5
6
6
import torch
7
7
from flashinfer .decode import trtllm_batch_decode_with_kv_cache_mla
12
12
MLACommonBackend ,
13
13
MLACommonImpl ,
14
14
MLACommonMetadata ,
15
+ MLACommonMetadataBuilder ,
15
16
)
17
+ from vllm .v1 .attention .backends .utils import AttentionCGSupport
16
18
17
19
logger = init_logger (__name__ )
18
20
19
21
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
20
22
21
23
24
+ class FlashInferMLAMetadataBuilder (MLACommonMetadataBuilder [MLACommonMetadata ]):
25
+ # enable full CUDA Graph support for decode-only capture
26
+ cudagraph_support : ClassVar [AttentionCGSupport ] = AttentionCGSupport .UNIFORM_BATCH
27
+
28
+
22
29
class FlashInferMLABackend (MLACommonBackend ):
23
30
@staticmethod
24
31
def get_name () -> str :
@@ -28,6 +35,10 @@ def get_name() -> str:
28
35
def get_impl_cls () -> type ["FlashInferMLAImpl" ]:
29
36
return FlashInferMLAImpl
30
37
38
+ @staticmethod
39
+ def get_builder_cls () -> type ["FlashInferMLAMetadataBuilder" ]:
40
+ return FlashInferMLAMetadataBuilder
41
+
31
42
32
43
g_fi_workspace = torch .zeros (
33
44
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE ,
You can’t perform that action at this time.
0 commit comments