diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index f0ea1d653c3e..13552edab87b 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
@@ -12,13 +12,20 @@
     MLACommonBackend,
     MLACommonImpl,
     MLACommonMetadata,
+    MLACommonMetadataBuilder,
 )
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
 FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
 
 
+class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable full CUDA Graph support for decode-only capture
+    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+
+
 class FlashInferMLABackend(MLACommonBackend):
     @staticmethod
     def get_name() -> str:
@@ -28,6 +35,10 @@ def get_name() -> str:
     def get_impl_cls() -> type["FlashInferMLAImpl"]:
         return FlashInferMLAImpl
 
+    @staticmethod
+    def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
+        return FlashInferMLAMetadataBuilder
+
 
 g_fi_workspace = torch.zeros(
     FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
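
A minimal usage sketch, not part of the change itself: it relies only on the names introduced in this diff (`FlashInferMLABackend.get_builder_cls`, the `cudagraph_support` ClassVar, and `AttentionCGSupport.UNIFORM_BATCH`) and assumes an environment where vLLM and FlashInfer are importable. It shows how a caller could read the builder's advertised CUDA graph support without instantiating it.

```python
# Hypothetical sketch, assuming only the names added in this diff.
from vllm.v1.attention.backends.mla.flashinfer_mla import FlashInferMLABackend
from vllm.v1.attention.backends.utils import AttentionCGSupport

# The builder class is exposed via the new static hook on the backend.
builder_cls = FlashInferMLABackend.get_builder_cls()

# The ClassVar is readable on the class itself, so a runner could decide on a
# capture mode before any metadata is built.
if builder_cls.cudagraph_support == AttentionCGSupport.UNIFORM_BATCH:
    print("decode-only (uniform) batches may be captured in a full CUDA graph")
```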