
Commit 2332243

[V1][CUDA] Full cudagraph support for FlashInfer (#21367)

1 parent 3654847 · commit 2332243
File tree: 8 files changed, +377 −48 lines changed

vllm/v1/attention/backends/flash_attn.py

Lines changed: 5 additions & 2 deletions

@@ -25,7 +25,8 @@
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(

 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
+        else AttentionCGSupport.ALWAYS

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
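
A minimal standalone sketch of how this version-gated attribute resolves. The AttentionCGSupport values are copied from the utils.py hunk further down; get_flash_attn_version is stubbed here rather than imported from vLLM, so the printed result only reflects the stub:

import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


def get_flash_attn_version() -> int:
    # Stub standing in for vLLM's detection of the installed FlashAttention.
    return 3


class FlashAttentionMetadataBuilderSketch:
    # FlashAttention 2 opts out of full CUDA graphs; FlashAttention 3 opts in.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
        else AttentionCGSupport.ALWAYS


print(FlashAttentionMetadataBuilderSketch.attn_cudagraph_support)
# AttentionCGSupport.ALWAYS (because the stub reports version 3)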

vllm/v1/attention/backends/flashinfer.py

Lines changed: 323 additions & 34 deletions
Large diffs are not rendered by default.

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 3 additions & 1 deletion

@@ -18,6 +18,7 @@
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec

 logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 3 additions & 1 deletion

@@ -17,6 +17,7 @@
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec

 # yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # decode only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

vllm/v1/attention/backends/triton_attn.py

Lines changed: 4 additions & 2 deletions

@@ -18,7 +18,8 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec

@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

 class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.ALWAYS

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):

vllm/v1/attention/backends/utils.py

Lines changed: 17 additions & 1 deletion

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
 M = TypeVar("M")


+class AttentionCGSupport(enum.Enum):
+    """Constants for the cudagraph support of the attention backend.
+    Here we do not consider cascade attention, as currently
+    it is never cudagraph supported."""
+
+    NEVER = 0
+    """No cudagraph support"""
+    PURE_DECODE_ONLY = 1
+    """Cudagraph supported for pure decode; must run without
+    cudagraph for mixed prefill-decode batches"""
+    ALWAYS = 2
+    """Cudagraph always supported"""
+
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
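
To make the intent of the three levels concrete, here is a self-contained sketch of how a consumer could branch on the attribute. SimpleBuilder and plan_cudagraph_capture are illustrative names for this sketch, not vLLM APIs; the enum mirrors the hunk above:

import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


class SimpleBuilder:
    # Backends override this class attribute; the base default stays conservative.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER


def plan_cudagraph_capture(builder_cls: type, full_cuda_graph: bool) -> str:
    # Decide what kind of CUDA graph capture a backend can participate in.
    support = builder_cls.attn_cudagraph_support
    if not full_cuda_graph or support is AttentionCGSupport.NEVER:
        return "piecewise/no full graphs"
    if support is AttentionCGSupport.PURE_DECODE_ONLY:
        return "full graphs for pure-decode batches only"
    return "full graphs for all batch compositions"


print(plan_cudagraph_capture(SimpleBuilder, full_cuda_graph=True))
# piecewise/no full graphs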

vllm/v1/worker/gpu_model_runner.py

Lines changed: 17 additions & 7 deletions

@@ -47,7 +47,7 @@
     is_pin_memory_available, round_up, supports_dynamo)
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -2619,12 +2619,22 @@ def _initialize_single_attn_backend(
             self.device,
         )

-        if (self.full_cuda_graph
-                and not attn_metadata_builder_i.full_cudagraph_supported):
-            raise ValueError(
-                f"Full CUDAGraph not supported for "
-                f"{attn_backend_i.__name__}. Turn off CompilationConfig."
-                f"full_cuda_graph or use a different attention backend.")
+        if self.full_cuda_graph:
+            if attn_metadata_builder_i.attn_cudagraph_support == \
+                    AttentionCGSupport.NEVER:
+                raise ValueError(f"Full CUDAGraph not supported for "
+                                 f"{attn_backend_i.__name__}. Turn off "
+                                 f"CompilationConfig.full_cuda_graph or use a "
+                                 f"different attention backend.")
+            if attn_metadata_builder_i.attn_cudagraph_support == \
+                    AttentionCGSupport.PURE_DECODE_ONLY:
+                # Limit the max cudagraph size to the max number of
+                # sequences for a pure-decode-only cudagraph backend,
+                # whose max_query_len is 1.
+                self.cudagraph_batch_sizes = [
+                    size for size in self.cudagraph_batch_sizes
+                    if size <= self.scheduler_config.max_num_seqs
+                ]
         return attn_backend_i, attn_metadata_builder_i

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
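
The PURE_DECODE_ONLY branch above trims the capture list because a decode-only graph covers at most one token per sequence. A standalone sketch with made-up numbers (limit_decode_only_sizes is a hypothetical helper, not part of the diff):

def limit_decode_only_sizes(cudagraph_batch_sizes: list[int],
                            max_num_seqs: int) -> list[int]:
    # A decode-only graph has max_query_len == 1, so a captured batch can
    # never hold more tokens than there are sequences in a step.
    return [size for size in cudagraph_batch_sizes if size <= max_num_seqs]


print(limit_decode_only_sizes([1, 2, 4, 8, 16, 32, 64, 128, 256], 64))
# [1, 2, 4, 8, 16, 32, 64]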

vllm/v1/worker/gpu_worker.py

Lines changed: 5 additions & 0 deletions

@@ -321,11 +321,16 @@ def compile_or_warm_up_model(self) -> None:
         if get_pp_group().is_last_rank:
             max_num_reqs = min(self.scheduler_config.max_num_seqs,
                                self.scheduler_config.max_num_batched_tokens)
+            # Activate building attn_metadata for this dummy run to avoid
+            # potential illegal memory access during full cudagraph replay.
+            attn_cudagraph = self.compilation_config.full_cuda_graph and \
+                not self.model_config.enforce_eager

             # We skip EPLB here since we don't want to record dummy metrics
             hidden_states, last_hidden_states = \
                 self.model_runner._dummy_run(
                     num_tokens=max_num_reqs,
+                    capture_attn_cudagraph=attn_cudagraph,
                     skip_eplb=True,
                 )
             if self.model_runner.is_pooling_model:
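
A small sketch of the warm-up decision added here; CompilationCfg and ModelCfg are stand-in dataclasses for this example, not vLLM's config objects:

from dataclasses import dataclass


@dataclass
class CompilationCfg:
    full_cuda_graph: bool


@dataclass
class ModelCfg:
    enforce_eager: bool


def should_capture_attn_cudagraph(compilation: CompilationCfg,
                                  model: ModelCfg) -> bool:
    # Only build attention metadata for the dummy run when full CUDA graphs
    # are requested and eager mode is not forced.
    return compilation.full_cuda_graph and not model.enforce_eager


print(should_capture_attn_cudagraph(CompilationCfg(True), ModelCfg(False)))  # True
print(should_capture_attn_cudagraph(CompilationCfg(True), ModelCfg(True)))   # False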
