13 changes: 12 additions & 1 deletion vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union
from typing import ClassVar, Optional, Union

import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
@@ -12,13 +12,20 @@
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport

logger = init_logger(__name__)

FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH


class FlashInferMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
@@ -28,6 +35,10 @@ def get_name() -> str:
def get_impl_cls() -> type["FlashInferMLAImpl"]:
return FlashInferMLAImpl

@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder


g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
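
For context, here is a minimal, self-contained sketch of the pattern this change relies on: the metadata builder advertises its CUDA graph capability through a `ClassVar`, and calling code can inspect it on the class itself without instantiating a builder. The diff above only confirms the names `AttentionCGSupport`, `UNIFORM_BATCH`, `MLACommonMetadataBuilder`, and `FlashInferMLAMetadataBuilder`; the `NEVER` value, the base-class default, and the `can_capture_decode_graph` helper below are assumptions for illustration, not vLLM's actual API.

```python
# Illustrative sketch of a class-level capability flag; not vLLM's implementation.
from enum import Enum, auto
from typing import ClassVar


class AttentionCGSupport(Enum):
    NEVER = auto()          # assumed value: no CUDA graph capture supported
    UNIFORM_BATCH = auto()  # assumed value: capture supported for uniform decode batches


class MLACommonMetadataBuilder:
    # assumed conservative default inherited by all MLA metadata builders
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder):
    # opt in to decode-only full CUDA graph capture, as in the diff above
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH


def can_capture_decode_graph(builder_cls: type[MLACommonMetadataBuilder]) -> bool:
    """Hypothetical runner-side check that reads the class attribute."""
    return builder_cls.cudagraph_support is AttentionCGSupport.UNIFORM_BATCH


if __name__ == "__main__":
    print(can_capture_decode_graph(FlashInferMLAMetadataBuilder))  # True
    print(can_capture_decode_graph(MLACommonMetadataBuilder))      # False
```

Because the flag is a `ClassVar`, overriding it in the FlashInfer subclass is enough for the rest of the stack to discover the new capability via `get_builder_cls()`; no instance state or extra wiring is needed.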