3 changes: 3 additions & 0 deletions vllm/config/__init__.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.config.attention import AttentionConfig
from vllm.config.cache import (
BlockSize,
CacheConfig,
Expand Down Expand Up @@ -57,6 +58,8 @@
)

__all__ = [
# From vllm.config.attention
"AttentionConfig",
# From vllm.config.cache
"BlockSize",
"CacheConfig",
Expand Down
77 changes: 77 additions & 0 deletions vllm/config/attention.py
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from typing import Any, Optional

from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

from vllm.config.utils import config


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""

backend: Optional[str] = None
"""Attention backend to use. If None, will be selected automatically.
Example options: FLASH_ATTN, XFORMERS, FLASHINFER, etc."""

use_triton_flash_attn: bool = True
"""Whether to use triton flash attention."""

flash_attn_version: Optional[int] = None
"""Force vllm to use a specific flash-attention version (2 or 3).
Only valid when using the flash-attention backend."""

v1_use_prefill_decode_attention: bool = False
"""Use separate prefill and decode kernels for V1 attention instead of
the unified triton kernel."""

use_aiter_unified_attention: bool = False
"""Use AITER triton unified attention for V1 attention."""

flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Flash Attention max number splits for cuda graph decode."""

use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill."""

use_trtllm_attention: Optional[bool] = None
"""If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer."""

disable_flashinfer_prefill: bool = False
"""Whether to disable flashinfer prefill."""

flashinfer_disable_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = [
self.backend,
self.use_triton_flash_attn,
self.flash_attn_version,
self.v1_use_prefill_decode_attention,
self.use_aiter_unified_attention,
self.flash_attn_max_num_splits_for_cuda_graph,
self.use_cudnn_prefill,
self.use_trtllm_attention,
self.disable_flashinfer_prefill,
self.flashinfer_disable_q_quantization,
]
Comment on lines +64 to +75 (Contributor, severity: high):

The manual listing of fields in compute_hash is fragile. If a new field is added to AttentionConfig in the future, a developer might forget to update this list, leading to incorrect cache hashes. This can cause subtle and hard-to-debug issues.

To make this more robust, you could programmatically collect the fields. This ensures that any new field is automatically included. Since all current fields seem to affect the computation graph, iterating over all fields is appropriate. This can be done without extra imports by using the __dataclass_fields__ attribute.

        factors = [getattr(self, f) for f in self.__dataclass_fields__]
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str
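As a quick check of the reviewer's __dataclass_fields__ suggestion above, here is a minimal standalone sketch. It uses a stand-in dataclass with only a few of the fields (hypothetical AttentionConfigSketch, not the real AttentionConfig or its pydantic decorators), so it runs without vLLM installed. Since pydantic's dataclass decorator wraps the stdlib one, the real class should expose __dataclass_fields__ the same way.

import hashlib
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttentionConfigSketch:
    # Stand-in with a subset of the real fields.
    backend: Optional[str] = None
    use_triton_flash_attn: bool = True
    flash_attn_version: Optional[int] = None

    def compute_hash(self) -> str:
        # Collect every declared field automatically instead of keeping a
        # manual factors list in sync by hand.
        factors = [getattr(self, name) for name in self.__dataclass_fields__]
        return hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()


default_hash = AttentionConfigSketch().compute_hash()
custom_hash = AttentionConfigSketch(backend="FLASH_ATTN").compute_hash()
assert default_hash != custom_hash  # different field values, different hash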
7 changes: 7 additions & 0 deletions vllm/config/vllm.py
@@ -20,6 +20,7 @@
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid

from .attention import AttentionConfig
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode
from .device import DeviceConfig
@@ -68,6 +69,8 @@ class VllmConfig:
"""Device configuration."""
load_config: LoadConfig = field(default_factory=LoadConfig)
"""Load configuration."""
attention_config: AttentionConfig = field(default_factory=AttentionConfig)
"""Attention configuration."""
lora_config: Optional[LoRAConfig] = None
"""LoRA configuration."""
speculative_config: Optional[SpeculativeConfig] = None
@@ -153,6 +156,10 @@ def compute_hash(self) -> str:
            vllm_factors.append(self.load_config.compute_hash())
        else:
            vllm_factors.append("None")
        if self.attention_config:
            vllm_factors.append(self.attention_config.compute_hash())
        else:
            vllm_factors.append("None")
Comment on lines +159 to +162 (Contributor, severity: high):

The attention_config field in VllmConfig is initialized with a default_factory and is not Optional. Therefore, self.attention_config will always be an instance of AttentionConfig and the if self.attention_config: check will always evaluate to true. This makes the else branch unreachable (dead code) and the conditional check redundant. You can simplify this by directly appending the hash.

        vllm_factors.append(self.attention_config.compute_hash())
        if self.lora_config:
            vllm_factors.append(self.lora_config.compute_hash())
            # LoRA creates static buffers based on max_num_batched_tokens.
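The reviewer's unreachable-else observation above is easy to demonstrate with a stand-in pair of dataclasses (hypothetical sketch names, not the real vLLM classes): a non-Optional field with a default_factory always holds an instance, and plain dataclass instances are truthy unless they define __bool__ or __len__.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class AttentionConfigSketch:
    backend: Optional[str] = None


@dataclass
class VllmConfigSketch:
    # Mirrors the new field: default_factory, not Optional.
    attention_config: AttentionConfigSketch = field(
        default_factory=AttentionConfigSketch
    )


cfg = VllmConfigSketch()
# Always truthy, so an `if cfg.attention_config:` guard never hits its else.
assert bool(cfg.attention_config) is True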
65 changes: 64 additions & 1 deletion vllm/engine/arg_utils.py
@@ -31,6 +31,7 @@

import vllm.envs as envs
from vllm.config import (
    AttentionConfig,
    BlockSize,
    CacheConfig,
    CacheDType,
@@ -494,6 +495,7 @@ class EngineArgs:
    )
    model_impl: str = ModelConfig.model_impl
    override_attention_dtype: str = ModelConfig.override_attention_dtype
    attention_backend: Optional[str] = AttentionConfig.backend

    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
@@ -655,6 +657,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--pt-load-map-location", **load_kwargs["pt_load_map_location"]
)

# Attention arguments
attention_group = parser.add_argument_group(
title="AttentionConfig",
description=AttentionConfig.__doc__,
)
attention_group.add_argument(
"--attention-backend",
type=str,
default=EngineArgs.attention_backend,
help="Attention backend to use. If not specified, will be selected "
"automatically. Example options: FLASH_ATTN, XFORMERS, FLASHINFER, "
"FLASHMLA, etc.",
)

# Structured outputs arguments
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
structured_outputs_group = parser.add_argument_group(
@@ -1201,6 +1217,50 @@ def create_speculative_config(
        )
        return SpeculativeConfig(**self.speculative_config)

    def create_attention_config(self) -> AttentionConfig:
        """Create attention configuration.

        This method reads from environment variables to maintain backward
        compatibility with existing deployments. All attention-related
        environment variables are respected:

        - VLLM_ATTENTION_BACKEND (deprecated, use --attention-backend CLI arg)
        - VLLM_USE_TRITON_FLASH_ATTN
        - VLLM_FLASH_ATTN_VERSION
        - VLLM_V1_USE_PREFILL_DECODE_ATTENTION
        - VLLM_USE_AITER_UNIFIED_ATTENTION
        - VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
        - VLLM_USE_CUDNN_PREFILL
        - VLLM_USE_TRTLLM_ATTENTION
        - VLLM_DISABLE_FLASHINFER_PREFILL
        - VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
        """

        # Warn if VLLM_ATTENTION_BACKEND env var is used instead of CLI arg
        if envs.is_set("VLLM_ATTENTION_BACKEND") and self.attention_backend is None:
Inline comment from a Contributor (severity: high):

The deprecation warning for VLLM_ATTENTION_BACKEND is only issued when the --attention-backend CLI argument is not provided. If a user sets both, the CLI argument silently overrides the environment variable. This can be confusing, as the user might not realize their environment variable is being ignored. It's better to always warn when the deprecated environment variable is set to avoid this confusion.

Suggested change:
-        if envs.is_set("VLLM_ATTENTION_BACKEND") and self.attention_backend is None:
+        if envs.is_set("VLLM_ATTENTION_BACKEND"):

            logger.warning(
                "Using VLLM_ATTENTION_BACKEND environment variable is deprecated "
                "and will be removed in a future release. "
                "Please use --attention-backend CLI argument instead."
            )

        # Handle backend: prefer CLI arg, fall back to env var
        backend = self.attention_backend
        if backend is None:
            backend = envs.VLLM_ATTENTION_BACKEND

        return AttentionConfig(
            backend=backend,
Comment on lines +1241 to +1255 (P1): Propagate --attention-backend CLI value to runtime

The new create_attention_config reads the --attention-backend CLI argument but only returns an AttentionConfig object; it never updates envs.VLLM_ATTENTION_BACKEND. The rest of the codebase (e.g. attention selector and platform-specific backends) still relies on the global env variable to decide which backend to load. As a result, specifying --attention-backend has no effect: all backend selection logic still sees None (or the old env value) and behaves as if the argument was never given. This makes the advertised CLI option a no-op and bypasses existing compatibility checks tied to the env variable. The CLI value should be written back into envs.VLLM_ATTENTION_BACKEND or the downstream logic should read from AttentionConfig instead.

            use_triton_flash_attn=envs.VLLM_USE_TRITON_FLASH_ATTN,
            flash_attn_version=envs.VLLM_FLASH_ATTN_VERSION,
            v1_use_prefill_decode_attention=envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION,
            use_aiter_unified_attention=envs.VLLM_USE_AITER_UNIFIED_ATTENTION,
            flash_attn_max_num_splits_for_cuda_graph=envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH,
            use_cudnn_prefill=envs.VLLM_USE_CUDNN_PREFILL,
            use_trtllm_attention=envs.VLLM_USE_TRTLLM_ATTENTION,
            disable_flashinfer_prefill=envs.VLLM_DISABLE_FLASHINFER_PREFILL,
            flashinfer_disable_q_quantization=envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION,
        )

    def create_engine_config(
        self,
        usage_context: Optional[UsageContext] = None,
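One possible shape of the fix the P1 comment asks for, sketched here as a hypothetical helper (not part of this diff): mirror the resolved CLI value back into the process environment, on the assumption that vllm.envs re-reads VLLM_ATTENTION_BACKEND from os.environ on each access, so backend-selection code that still consults the env-backed value sees the CLI choice. The alternative the reviewer mentions, having downstream logic read AttentionConfig directly, would avoid the global entirely.

import os
from typing import Optional


def propagate_attention_backend(backend: Optional[str]) -> None:
    """Hypothetical helper: mirror the CLI-provided backend into the legacy
    environment variable so env-based selection logic stays consistent."""
    if backend is not None:
        os.environ["VLLM_ATTENTION_BACKEND"] = backend


# e.g. inside create_attention_config, after `backend` is resolved:
# propagate_attention_backend(backend)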
@@ -1543,15 +1603,18 @@ def create_engine_config(
            collect_detailed_traces=self.collect_detailed_traces,
        )

        attention_config = self.create_attention_config()

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            load_config=load_config,
            attention_config=attention_config,
            lora_config=lora_config,
            speculative_config=speculative_config,
            load_config=load_config,
            structured_outputs_config=self.structured_outputs_config,
            observability_config=observability_config,
            compilation_config=self.compilation_config,
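Finally, a hedged end-to-end usage sketch (assumes a vLLM build that includes this PR; exact field defaults and env-var parsing may differ): both the new attention_backend argument and the legacy environment variables should surface on the AttentionConfig returned by create_attention_config().

import os

from vllm.engine.arg_utils import EngineArgs

# Legacy env-var path is still honored by create_attention_config().
os.environ["VLLM_USE_CUDNN_PREFILL"] = "1"

# New explicit path added by this PR.
args = EngineArgs(attention_backend="FLASH_ATTN")
attention_config = args.create_attention_config()

print(attention_config.backend)            # expected: FLASH_ATTN
print(attention_config.use_cudnn_prefill)  # expected: True (from the env var)
print(attention_config.compute_hash())     # feeds into VllmConfig.compute_hash()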