96 changes: 34 additions & 62 deletions vllm/attention/selector.py
@@ -52,32 +52,9 @@
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None


def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
"""
Force all attention operations to use a specified backend.

Passing `None` for the argument re-enables automatic
backend selection.

Arguments:

* attn_backend: backend selection (None to revert to auto)
"""
global forced_attn_backend
forced_attn_backend = attn_backend


def get_global_forced_attn_backend() -> Optional[_Backend]:
"""
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
"""
return forced_attn_backend
# NOTE: The global forced backend mechanism has been removed.
# To override the attention backend, modify vllm_config.attention_config.backend
# using get_current_vllm_config().attention_config.backend = "BACKEND_NAME"


@dataclass(frozen=True)
@@ -177,32 +154,27 @@
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
# Check the config (which may come from CLI arg, env var, or runtime override)
from vllm.config import get_current_vllm_config

selected_backend = None
backend_by_global_setting: Optional[_Backend] = get_global_forced_attn_backend()
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
if backend_by_env_var.endswith("_VLLM_V1"):
logger.warning(
"The suffix '_VLLM_V1' in the environment variable "
"%s is no longer necessary as V0 backends have been "
"deprecated. Please remove this suffix from your "
"environment variable setting.",
STR_BACKEND_ENV_VAR,
)
backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1")
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}"
)
vllm_config = get_current_vllm_config()
backend_by_config: Optional[str] = vllm_config.attention_config.backend
if backend_by_config is not None:
if backend_by_config.endswith("_VLLM_V1"):
logger.warning(
"The suffix '_VLLM_V1' in the attention backend "
"is no longer necessary as V0 backends have been "
"deprecated. Please remove this suffix from your "
"backend setting."
)
backend_by_config = backend_by_config.removesuffix("_VLLM_V1")
selected_backend = _Backend.backend_name_to_enum(backend_by_config)

Check failure on line 172 in vllm/attention/selector.py (GitHub Actions / pre-commit): "type[_Backend]" has no attribute "backend_name_to_enum" [attr-defined]
if selected_backend is None:
raise ValueError(
f"Invalid attention backend: '{backend_by_config}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}"
)

# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
@@ -228,29 +200,29 @@
attn_backend: _Backend,
) -> Generator[None, None, None]:
"""
Globally force a vLLM attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.
Temporarily override the attention backend within a context manager,
reverting to the original backend upon exiting.

Arguments:

* attn_backend: attention backend to force
* attn_backend: attention backend to use

Returns:

* Generator
"""
from vllm.config import get_current_vllm_config

# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()
# Save the current backend from config
vllm_config = get_current_vllm_config()
original_value = vllm_config.attention_config.backend

# Globally force the new backend override
global_force_attn_backend(attn_backend)
# Override the backend in config
vllm_config.attention_config.backend = str(attn_backend.name)

# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
# Revert the original backend
vllm_config.attention_config.backend = original_value
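As a usage illustration (not part of the diff): the sketch below shows how the config-based override replaces the removed global_force_attn_backend helper. It assumes it runs where a current VllmConfig is active, so get_current_vllm_config() returns the config in use; temporary_backend is a hypothetical helper mirroring the context manager above.

    import contextlib
    from collections.abc import Generator

    from vllm.config import get_current_vllm_config


    def force_flashinfer_backend() -> None:
        # Direct replacement for the removed global_force_attn_backend(...) call:
        # write the backend name into the attention config instead.
        get_current_vllm_config().attention_config.backend = "FLASHINFER"


    @contextlib.contextmanager
    def temporary_backend(name: str) -> Generator[None, None, None]:
        # Hypothetical helper mirroring the save/override/restore pattern above.
        vllm_config = get_current_vllm_config()
        original = vllm_config.attention_config.backend
        vllm_config.attention_config.backend = name
        try:
            yield
        finally:
            vllm_config.attention_config.backend = original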
12 changes: 7 additions & 5 deletions vllm/attention/utils/fa_utils.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

@@ -42,10 +41,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2
)

# 2. override if passed by environment
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
fa_version = envs.VLLM_FLASH_ATTN_VERSION
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
if vllm_config.attention_config.flash_attn_version is not None:
assert vllm_config.attention_config.flash_attn_version in [2, 3]
fa_version = vllm_config.attention_config.flash_attn_version

# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
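As a usage illustration (not part of the diff), the VLLM_FLASH_ATTN_VERSION environment variable now maps to a config field; the sketch below assumes a current VllmConfig is active.

    from vllm.config import get_current_vllm_config

    # Equivalent of exporting VLLM_FLASH_ATTN_VERSION=3 before this change;
    # get_flash_attn_version() will then return 3 where the device supports FA3.
    get_current_vllm_config().attention_config.flash_attn_version = 3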
3 changes: 3 additions & 0 deletions vllm/config/__init__.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.config.attention import AttentionConfig
from vllm.config.cache import (
BlockSize,
CacheConfig,
@@ -57,6 +58,8 @@
)

__all__ = [
# From vllm.config.attention
"AttentionConfig",
# From vllm.config.cache
"BlockSize",
"CacheConfig",
77 changes: 77 additions & 0 deletions vllm/config/attention.py
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from typing import Any, Optional

from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

from vllm.config.utils import config


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""

backend: Optional[str] = None
"""Attention backend to use. If None, will be selected automatically.
Example options: FLASH_ATTN, XFORMERS, FLASHINFER, etc."""

use_triton_flash_attn: bool = True
"""Whether to use triton flash attention."""

flash_attn_version: Optional[int] = None
"""Force vllm to use a specific flash-attention version (2 or 3).
Only valid when using the flash-attention backend."""

v1_use_prefill_decode_attention: bool = False
"""Use separate prefill and decode kernels for V1 attention instead of
the unified triton kernel."""

use_aiter_unified_attention: bool = False
"""Use AITER triton unified attention for V1 attention."""

flash_attn_max_num_splits_for_cuda_graph: int = 32
"""Flash Attention max number splits for cuda graph decode."""

use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill."""

use_trtllm_attention: Optional[bool] = None
"""If set to True/False, use or don't use the TRTLLM attention backend
in flashinfer. If None, auto-detect the attention backend in flashinfer."""

disable_flashinfer_prefill: bool = False
"""Whether to disable flashinfer prefill."""

flashinfer_disable_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.
Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
factors: list[Any] = [
self.backend,
self.use_triton_flash_attn,
self.flash_attn_version,
self.v1_use_prefill_decode_attention,
self.use_aiter_unified_attention,
self.flash_attn_max_num_splits_for_cuda_graph,
self.use_cudnn_prefill,
self.use_trtllm_attention,
self.disable_flashinfer_prefill,
self.flashinfer_disable_q_quantization,
]
Contributor comment on lines +64 to +75 (severity: high):
The manual listing of fields in compute_hash is fragile. If a new field is added to AttentionConfig in the future, a developer might forget to update this list, leading to incorrect cache hashes. This can cause subtle and hard-to-debug issues.

To make this more robust, you could programmatically collect the fields. This ensures that any new field is automatically included. Since all current fields seem to affect the computation graph, iterating over all fields is appropriate. This can be done without extra imports by using the __dataclass_fields__ attribute.

        factors = [getattr(self, f) for f in self.__dataclass_fields__]

hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
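For reference, a short sketch (not part of the diff) of constructing the new config and hashing it, using only the fields and compute_hash method defined above.

    from vllm.config import AttentionConfig

    cfg = AttentionConfig(backend="FLASHINFER", flash_attn_version=3)
    default_cfg = AttentionConfig()
    # compute_hash folds the graph-affecting fields into an md5 digest,
    # so configs that differ in those fields hash differently.
    assert cfg.compute_hash() != default_cfg.compute_hash()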
19 changes: 12 additions & 7 deletions vllm/config/model.py
@@ -29,6 +29,9 @@
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger
from vllm.platforms import current_platform

if TYPE_CHECKING:
from vllm.config.attention import AttentionConfig
from vllm.transformers_utils.config import (
ConfigFormat,
get_config,
@@ -280,6 +283,9 @@ class ModelConfig:
"""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
attention_config: Optional["AttentionConfig"] = None
"""Attention configuration. If not specified, will be read from environment
variables."""
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
"""One or more logits processors' fully-qualified class names or class
definitions"""
@@ -444,13 +450,11 @@ def __post_init__(

self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)

if (
(backend := envs.VLLM_ATTENTION_BACKEND)
and backend == "FLASHINFER"
and find_spec("flashinfer") is None
):
# Early validation for FLASHINFER backend
backend = self.attention_config.backend if self.attention_config else None
if backend == "FLASHINFER" and find_spec("flashinfer") is None:
raise ValueError(
"VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer "
"attention_backend is set to FLASHINFER, but flashinfer "
"module was not found. See "
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it."
@@ -637,7 +641,8 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
not self.disable_sliding_window
and is_interleaved(self.hf_text_config)
and not envs.VLLM_USE_V1
and (backend := envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER")
and backend is not None
and backend in ("XFORMERS", "FLASHINFER")
):
logger.warning_once(
"%s has interleaved attention, which is currently not "
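A standalone sketch (not part of the diff) of the same fail-fast validation pattern, with a hypothetical helper name:

    from importlib.util import find_spec
    from typing import Optional


    def check_flashinfer_available(backend: Optional[str]) -> None:
        # Mirrors the ModelConfig check above: fail fast when FLASHINFER is
        # requested but the flashinfer package cannot be imported.
        if backend == "FLASHINFER" and find_spec("flashinfer") is None:
            raise ValueError(
                "attention_backend is set to FLASHINFER, but the flashinfer "
                "module was not found."
            )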
7 changes: 7 additions & 0 deletions vllm/config/vllm.py
@@ -20,6 +20,7 @@
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid

from .attention import AttentionConfig
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode
from .device import DeviceConfig
@@ -68,6 +69,8 @@ class VllmConfig:
"""Device configuration."""
load_config: LoadConfig = field(default_factory=LoadConfig)
"""Load configuration."""
attention_config: AttentionConfig = field(default_factory=AttentionConfig)
"""Attention configuration."""
lora_config: Optional[LoRAConfig] = None
"""LoRA configuration."""
speculative_config: Optional[SpeculativeConfig] = None
@@ -153,6 +156,10 @@ def compute_hash(self) -> str:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.attention_config:
vllm_factors.append(self.attention_config.compute_hash())
else:
vllm_factors.append("None")
Contributor comment on lines +159 to +162 (severity: high):
The attention_config field in VllmConfig is initialized with a default_factory and is not Optional. Therefore, self.attention_config will always be an instance of AttentionConfig and the if self.attention_config: check will always evaluate to true. This makes the else branch unreachable (dead code) and the conditional check redundant. You can simplify this by directly appending the hash.

        vllm_factors.append(self.attention_config.compute_hash())

if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
# LoRA creates static buffers based on max_num_batched_tokens.