diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index effd35444d54..d2a73238d450 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -52,32 +52,9 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: # a backend based on system & workload configuration # (default behavior if this variable is None) # -# THIS SELECTION TAKES PRECEDENCE OVER THE -# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE -forced_attn_backend: Optional[_Backend] = None - - -def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: - """ - Force all attention operations to use a specified backend. - - Passing `None` for the argument re-enables automatic - backend selection., - - Arguments: - - * attn_backend: backend selection (None to revert to auto) - """ - global forced_attn_backend - forced_attn_backend = attn_backend - - -def get_global_forced_attn_backend() -> Optional[_Backend]: - """ - Get the currently-forced choice of attention backend, - or None if auto-selection is currently enabled. - """ - return forced_attn_backend +# NOTE: The global forced backend mechanism has been removed. +# To override the attention backend, modify vllm_config.attention_config.backend +# using get_current_vllm_config().attention_config.backend = "BACKEND_NAME" @dataclass(frozen=True) @@ -177,32 +154,27 @@ def _cached_get_attn_backend( ) -> type[AttentionBackend]: # Check whether a particular choice of backend was # previously forced. - # - # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND - # ENVIRONMENT VARIABLE. + # Check the config (which may come from CLI arg, env var, or runtime override) + from vllm.config import get_current_vllm_config + selected_backend = None - backend_by_global_setting: Optional[_Backend] = get_global_forced_attn_backend() - if backend_by_global_setting is not None: - selected_backend = backend_by_global_setting - else: - # Check the environment variable and override if specified - backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND - if backend_by_env_var is not None: - if backend_by_env_var.endswith("_VLLM_V1"): - logger.warning( - "The suffix '_VLLM_V1' in the environment variable " - "%s is no longer necessary as V0 backends have been " - "deprecated. Please remove this suffix from your " - "environment variable setting.", - STR_BACKEND_ENV_VAR, - ) - backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") - selected_backend = backend_name_to_enum(backend_by_env_var) - if selected_backend is None: - raise ValueError( - f"Invalid attention backend: '{backend_by_env_var}'. " - f"Valid backends are: {list(_Backend.__members__.keys())}" - ) + vllm_config = get_current_vllm_config() + backend_by_config: Optional[str] = vllm_config.attention_config.backend + if backend_by_config is not None: + if backend_by_config.endswith("_VLLM_V1"): + logger.warning( + "The suffix '_VLLM_V1' in the attention backend " + "is no longer necessary as V0 backends have been " + "deprecated. Please remove this suffix from your " + "backend setting." + ) + backend_by_config = backend_by_config.removesuffix("_VLLM_V1") + selected_backend = _Backend.backend_name_to_enum(backend_by_config) + if selected_backend is None: + raise ValueError( + f"Invalid attention backend: '{backend_by_config}'. 
" + f"Valid backends are: {list(_Backend.__members__.keys())}" + ) # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( @@ -228,29 +200,29 @@ def global_force_attn_backend_context_manager( attn_backend: _Backend, ) -> Generator[None, None, None]: """ - Globally force a vLLM attention backend override within a - context manager, reverting the global attention backend - override to its prior state upon exiting the context - manager. + Temporarily override the attention backend within a context manager, + reverting to the original backend upon exiting. Arguments: - * attn_backend: attention backend to force + * attn_backend: attention backend to use Returns: * Generator """ + from vllm.config import get_current_vllm_config - # Save the current state of the global backend override (if any) - original_value = get_global_forced_attn_backend() + # Save the current backend from config + vllm_config = get_current_vllm_config() + original_value = vllm_config.attention_config.backend - # Globally force the new backend override - global_force_attn_backend(attn_backend) + # Override the backend in config + vllm_config.attention_config.backend = str(attn_backend.name) # Yield control back to the enclosed code block try: yield finally: - # Revert the original global backend override, if any - global_force_attn_backend(original_value) + # Revert the original backend + vllm_config.attention_config.backend = original_value diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index e13afd46ee96..7158b10ee6f1 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Optional -from vllm import envs from vllm.logger import init_logger from vllm.platforms import current_platform @@ -42,10 +41,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: 3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 ) - # 2. override if passed by environment - if envs.VLLM_FLASH_ATTN_VERSION is not None: - assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] - fa_version = envs.VLLM_FLASH_ATTN_VERSION + # 2. override if passed by environment or config + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + if vllm_config.attention_config.flash_attn_version is not None: + assert vllm_config.attention_config.flash_attn_version in [2, 3] + fa_version = vllm_config.attention_config.flash_attn_version # 3. 
fallback for unsupported combinations if device_capability.major == 10 and fa_version == 3: diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 7c5052c822f8..2201dd5898f3 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.config.attention import AttentionConfig from vllm.config.cache import ( BlockSize, CacheConfig, @@ -57,6 +58,8 @@ ) __all__ = [ + # From vllm.config.attention + "AttentionConfig", # From vllm.config.cache "BlockSize", "CacheConfig", diff --git a/vllm/config/attention.py b/vllm/config/attention.py new file mode 100644 index 000000000000..e860da983145 --- /dev/null +++ b/vllm/config/attention.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +from typing import Any, Optional + +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class AttentionConfig: + """Configuration for attention mechanisms in vLLM.""" + + backend: Optional[str] = None + """Attention backend to use. If None, will be selected automatically. + Example options: FLASH_ATTN, XFORMERS, FLASHINFER, etc.""" + + use_triton_flash_attn: bool = True + """Whether to use triton flash attention.""" + + flash_attn_version: Optional[int] = None + """Force vllm to use a specific flash-attention version (2 or 3). + Only valid when using the flash-attention backend.""" + + v1_use_prefill_decode_attention: bool = False + """Use separate prefill and decode kernels for V1 attention instead of + the unified triton kernel.""" + + use_aiter_unified_attention: bool = False + """Use AITER triton unified attention for V1 attention.""" + + flash_attn_max_num_splits_for_cuda_graph: int = 32 + """Flash Attention max number splits for cuda graph decode.""" + + use_cudnn_prefill: bool = False + """Whether to use cudnn prefill.""" + + use_trtllm_attention: Optional[bool] = None + """If set to True/False, use or don't use the TRTLLM attention backend + in flashinfer. If None, auto-detect the attention backend in flashinfer.""" + + disable_flashinfer_prefill: bool = False + """Whether to disable flashinfer prefill.""" + + flashinfer_disable_q_quantization: bool = False + """If set, when using fp8 kv, do not quantize Q to fp8.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. 
+ """ + factors: list[Any] = [ + self.backend, + self.use_triton_flash_attn, + self.flash_attn_version, + self.v1_use_prefill_decode_attention, + self.use_aiter_unified_attention, + self.flash_attn_max_num_splits_for_cuda_graph, + self.use_cudnn_prefill, + self.use_trtllm_attention, + self.disable_flashinfer_prefill, + self.flashinfer_disable_q_quantization, + ] + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + return hash_str diff --git a/vllm/config/model.py b/vllm/config/model.py index 146ace9782b9..79733a9ef10c 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -29,6 +29,9 @@ from vllm.config.utils import assert_hashable, config, getattr_iter from vllm.logger import init_logger from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.config.attention import AttentionConfig from vllm.transformers_utils.config import ( ConfigFormat, get_config, @@ -280,6 +283,9 @@ class ModelConfig: """ override_attention_dtype: Optional[str] = None """Override dtype for attention""" + attention_config: Optional["AttentionConfig"] = None + """Attention configuration. If not specified, will be read from environment + variables.""" logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None """One or more logits processors' fully-qualified class names or class definitions""" @@ -444,13 +450,11 @@ def __post_init__( self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - if ( - (backend := envs.VLLM_ATTENTION_BACKEND) - and backend == "FLASHINFER" - and find_spec("flashinfer") is None - ): + # Early validation for FLASHINFER backend + backend = self.attention_config.backend if self.attention_config else None + if backend == "FLASHINFER" and find_spec("flashinfer") is None: raise ValueError( - "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " + "attention_backend is set to FLASHINFER, but flashinfer " "module was not found. See " "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 "for instructions on how to install it." 
@@ -637,7 +641,8 @@ def _task_to_convert(task: TaskOption) -> ConvertType: not self.disable_sliding_window and is_interleaved(self.hf_text_config) and not envs.VLLM_USE_V1 - and (backend := envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER") + and backend is not None + and backend in ("XFORMERS", "FLASHINFER") ): logger.warning_once( "%s has interleaved attention, which is currently not " diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b5856958ce2e..dce8a05f0098 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -20,6 +20,7 @@ from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid +from .attention import AttentionConfig from .cache import CacheConfig from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode from .device import DeviceConfig @@ -68,6 +69,8 @@ class VllmConfig: """Device configuration.""" load_config: LoadConfig = field(default_factory=LoadConfig) """Load configuration.""" + attention_config: AttentionConfig = field(default_factory=AttentionConfig) + """Attention configuration.""" lora_config: Optional[LoRAConfig] = None """LoRA configuration.""" speculative_config: Optional[SpeculativeConfig] = None @@ -153,6 +156,10 @@ def compute_hash(self) -> str: vllm_factors.append(self.load_config.compute_hash()) else: vllm_factors.append("None") + if self.attention_config: + vllm_factors.append(self.attention_config.compute_hash()) + else: + vllm_factors.append("None") if self.lora_config: vllm_factors.append(self.lora_config.compute_hash()) # LoRA creates static buffers based on max_num_batched_tokens. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a94ef598f2de..02d3fbd4ae06 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,6 +31,7 @@ import vllm.envs as envs from vllm.config import ( + AttentionConfig, BlockSize, CacheConfig, CacheDType, @@ -494,6 +495,7 @@ class EngineArgs: ) model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype + attention_backend: Optional[str] = AttentionConfig.backend calculate_kv_scales: bool = CacheConfig.calculate_kv_scales mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype @@ -655,6 +657,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--pt-load-map-location", **load_kwargs["pt_load_map_location"] ) + # Attention arguments + attention_group = parser.add_argument_group( + title="AttentionConfig", + description=AttentionConfig.__doc__, + ) + attention_group.add_argument( + "--attention-backend", + type=str, + default=EngineArgs.attention_backend, + help="Attention backend to use. If not specified, will be selected " + "automatically. 
Example options: FLASH_ATTN, XFORMERS, FLASHINFER, " + "FLASHMLA, etc.", + ) + # Structured outputs arguments structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) structured_outputs_group = parser.add_argument_group( @@ -1041,7 +1057,9 @@ def from_cli_args(cls, args: argparse.Namespace): ) return engine_args - def create_model_config(self) -> ModelConfig: + def create_model_config( + self, attention_config: Optional[AttentionConfig] = None + ) -> ModelConfig: # gguf file needs a specific model loader and doesn't use hf_repo if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" @@ -1133,6 +1151,7 @@ def create_model_config(self) -> ModelConfig: logits_processors=self.logits_processors, video_pruning_rate=self.video_pruning_rate, io_processor_plugin=self.io_processor_plugin, + attention_config=attention_config, ) def validate_tensorizer_args(self): @@ -1201,6 +1220,50 @@ def create_speculative_config( ) return SpeculativeConfig(**self.speculative_config) + def create_attention_config(self) -> AttentionConfig: + """Create attention configuration. + + This method reads from environment variables to maintain backward + compatibility with existing deployments. All attention-related + environment variables are respected: + - VLLM_ATTENTION_BACKEND (deprecated, use --attention-backend CLI arg) + - VLLM_USE_TRITON_FLASH_ATTN + - VLLM_FLASH_ATTN_VERSION + - VLLM_V1_USE_PREFILL_DECODE_ATTENTION + - VLLM_USE_AITER_UNIFIED_ATTENTION + - VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + - VLLM_USE_CUDNN_PREFILL + - VLLM_USE_TRTLLM_ATTENTION + - VLLM_DISABLE_FLASHINFER_PREFILL + - VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION + """ + + # Warn if VLLM_ATTENTION_BACKEND env var is used instead of CLI arg + if envs.is_set("VLLM_ATTENTION_BACKEND") and self.attention_backend is None: + logger.warning( + "Using VLLM_ATTENTION_BACKEND environment variable is deprecated " + "and will be removed in a future release. " + "Please use --attention-backend CLI argument instead." 
+ ) + + # Handle backend: prefer CLI arg, fall back to env var + backend = self.attention_backend + if backend is None: + backend = envs.VLLM_ATTENTION_BACKEND + + return AttentionConfig( + backend=backend, + use_triton_flash_attn=envs.VLLM_USE_TRITON_FLASH_ATTN, + flash_attn_version=envs.VLLM_FLASH_ATTN_VERSION, + v1_use_prefill_decode_attention=envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION, + use_aiter_unified_attention=envs.VLLM_USE_AITER_UNIFIED_ATTENTION, + flash_attn_max_num_splits_for_cuda_graph=envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH, + use_cudnn_prefill=envs.VLLM_USE_CUDNN_PREFILL, + use_trtllm_attention=envs.VLLM_USE_TRTLLM_ATTENTION, + disable_flashinfer_prefill=envs.VLLM_DISABLE_FLASHINFER_PREFILL, + flashinfer_disable_q_quantization=envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION, + ) + def create_engine_config( self, usage_context: Optional[UsageContext] = None, @@ -1223,7 +1286,10 @@ def create_engine_config( device_config = DeviceConfig(device=cast(Device, current_platform.device_type)) - model_config = self.create_model_config() + # Create AttentionConfig first so ModelConfig can use it + attention_config = self.create_attention_config() + + model_config = self.create_model_config(attention_config=attention_config) self.model = model_config.model self.tokenizer = model_config.tokenizer @@ -1543,15 +1609,19 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) + # Note: attention_config was already created earlier in this method + # (before creating model_config) so that ModelConfig can use it + config = VllmConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, scheduler_config=scheduler_config, device_config=device_config, + load_config=load_config, + attention_config=attention_config, lora_config=lora_config, speculative_config=speculative_config, - load_config=load_config, structured_outputs_config=self.structured_outputs_config, observability_config=observability_config, compilation_config=self.compilation_config, @@ -1624,11 +1694,12 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "XFORMERS", "ROCM_ATTN", ] - if ( - envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS - ): - name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" + # Get backend from CLI arg or env var + backend = self.attention_backend + if backend is None: + backend = envs.VLLM_ATTENTION_BACKEND + if backend is not None and backend not in V1_BACKENDS: + name = f"VLLM_ATTENTION_BACKEND={backend}" _raise_or_fallback(feature_name=name, recommend_to_remove=True) return False diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 20568e0d6c51..b0ae3a9815b1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -128,23 +128,23 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: use_cutlass_mla = False use_flashinfer_mla = False - if envs.VLLM_ATTENTION_BACKEND is None: + attention_backend = vllm_config.attention_config.backend + if attention_backend is None: # Default case if cls.is_device_capability(100): # Blackwell => Force CutlassMLA. use_cutlass_mla = True - # TODO: This does not work, because the - # global_force_attn_backend_context_manager is not set. 
- # See vllm/attention/selector.py:_cached_get_attn_backend - envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA" + # Set the backend in AttentionConfig so it's used during + # backend selection + vllm_config.attention_config.backend = "CUTLASS_MLA" else: # Not Blackwell use_flashmla = True else: # Forced case - use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" - use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" - use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" + use_flashmla = attention_backend == "FLASHMLA" + use_cutlass_mla = attention_backend == "CUTLASS_MLA" + use_flashinfer_mla = attention_backend == "FLASHINFER_MLA" from vllm.attention.ops.flashmla import is_flashmla_supported @@ -481,8 +481,11 @@ def device_count(cls) -> int: def is_kv_cache_dtype_supported( cls, kv_cache_dtype: str, model_config: "ModelConfig" ) -> bool: + from vllm.config import get_current_vllm_config + fp8_attention = kv_cache_dtype.startswith("fp8") - attention_backend = envs.VLLM_ATTENTION_BACKEND + vllm_config = get_current_vllm_config() + attention_backend = vllm_config.attention_config.backend supported = False if model_config is not None and model_config.use_mla: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 80e7b849c0ed..46f8fabf8f25 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -276,6 +276,9 @@ def get_attn_backend_cls( ) if envs.VLLM_USE_V1: + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): logger.info("Using Flash Attention backend on V1 engine.") return ( @@ -283,8 +286,11 @@ def get_attn_backend_cls( "rocm_aiter_fa.AiterFlashAttentionBackend" ) elif ( - (envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION) - or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + ( + envs.VLLM_ROCM_USE_AITER + and vllm_config.attention_config.use_aiter_unified_attention + ) + or vllm_config.attention_config.v1_use_prefill_decode_attention or selected_backend == _Backend.ROCM_ATTN ): # rocm specific backend, with aiter and/or diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 2f2f3ab8b9d9..06e8f6db367c 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -83,9 +83,12 @@ def is_kv_cache_dtype_supported( Check if the kv_cache_dtype is supported. XPU only support fp8 kv cache with triton backend. """ + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() if ( envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN" + and vllm_config.attention_config.backend == "TRITON_ATTN" ): return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 1d707d56daba..79122dd274c1 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -215,7 +215,12 @@ def force_use_trtllm_attention() -> bool | None: return ``True`` if TRTLLM attention is forced to be used, return ``False`` if TRTLLM attention is forced to be not used. 
""" - return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + return _force_use_trtllm_attention( + vllm_config.attention_config.use_trtllm_attention + ) def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool: @@ -434,8 +439,11 @@ def flashinfer_scaled_fp8_mm( @functools.cache def flashinfer_disable_q_quantization() -> bool: - """Cache result which only depends on the environment""" - return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION + """Cache result which only depends on the attention config""" + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + return vllm_config.attention_config.flashinfer_disable_q_quantization __all__ = [ diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index bb3dcddba3e9..1e683773d17f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -8,7 +8,6 @@ import numpy as np import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -200,6 +199,7 @@ def __init__( self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config + self.attention_config = vllm_config.attention_config self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config @@ -233,7 +233,9 @@ def __init__( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + self.attention_config.flash_attn_max_num_splits_for_cuda_graph + ) # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f7ec18f5e9f6..82a467383f8d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -425,18 +425,24 @@ def __post_init__(self): def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. 
+ from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( - not envs.VLLM_DISABLE_FLASHINFER_PREFILL + not vllm_config.attention_config.disable_flashinfer_prefill and flashinfer_available - and not envs.VLLM_USE_CUDNN_PREFILL + and not vllm_config.attention_config.use_cudnn_prefill and current_platform.is_device_capability(100) ) def use_cudnn_prefill() -> bool: + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() return ( flashinfer_available - and envs.VLLM_USE_CUDNN_PREFILL + and vllm_config.attention_config.use_cudnn_prefill and current_platform.is_device_capability(100) and has_nvidia_artifactory() ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c0c2dbe1f961..0e3b75ae4d1f 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -6,7 +6,6 @@ import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, @@ -105,7 +104,9 @@ def __init__( # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + self.max_num_splits = ( + vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph + ) # TODO(lucas): Until we add support for the DCP custom masking we need # to restrict decodes to q_len == 1 when DCP is enabled. diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 3b6718c48d09..9a894458a792 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -5,7 +5,6 @@ import torch -from vllm import envs from vllm.attention.backends.abstract import ( AttentionLayer, AttentionType, @@ -87,7 +86,10 @@ def __init__( "TritonMLA V1 with FP8 KV cache not yet supported" ) - self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + self.use_triton_flash_attn = vllm_config.attention_config.use_triton_flash_attn self.triton_fa_func = triton_attention if HAS_TRITON else None def _flash_attn_varlen_diff_headdims_rocm( diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 4c24770aa22c..11f7233e3c52 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -79,6 +79,7 @@ def __init__( super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size + self.attention_config = vllm_config.attention_config model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( @@ -96,7 +97,7 @@ def build_for_cudagraph_capture( # slow, so here we set it to 1. attn_metadata.seq_lens.fill_(1) - if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION: + if self.attention_config.v1_use_prefill_decode_attention: # Here we set the query start locs to 0. 
This is to # cover up an invalid memory access in the prefix_prefil kernel # that we run into during graph capture (#25985) @@ -216,7 +217,13 @@ def use_aiter_unified_attention() -> bool: """Check if aiter unified attention should be used.""" # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set # to 1 as default - return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + return ( + envs.VLLM_ROCM_USE_AITER + and vllm_config.attention_config.use_aiter_unified_attention + ) class RocmAttentionImpl(AttentionImpl): @@ -267,8 +274,13 @@ def __init__( "RocmAttentionImpl" ) + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + self.force_prefill_decode_attn = ( + vllm_config.attention_config.v1_use_prefill_decode_attention + ) if not self.force_prefill_decode_attn: # If not using prefill decode attention, we use the Triton diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 003c7253e553..3a5477b0b315 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -446,7 +446,10 @@ def infer_global_hyperparameters( global_params = param_sets[0] # trtllm attention doesn't need global hyper params so disable the check - if not envs.VLLM_USE_TRTLLM_ATTENTION: + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + if not vllm_config.attention_config.use_trtllm_attention: for params in param_sets: if params.window_left != global_params.window_left: raise ValueError(
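Note (outside the patch itself): the NOTE added to vllm/attention/selector.py and the new AttentionConfig dataclass in vllm/config/attention.py replace the old global_force_attn_backend() mechanism with a config-driven override. A minimal usage sketch, assuming the field names and the vllm.config exports introduced in this diff (whether additional fields beyond backend need to be set depends on the platform defaults):

# Illustrative only -- not part of the patch. Assumes the AttentionConfig
# fields and the get_current_vllm_config() helper used throughout the diff.
from vllm.config import AttentionConfig, get_current_vllm_config

# Build an explicit config, e.g. when assembling a VllmConfig by hand:
attn_cfg = AttentionConfig(backend="FLASH_ATTN", flash_attn_version=3)

# Or override the backend at runtime, which is the replacement for the
# removed global_force_attn_backend() call:
vllm_config = get_current_vllm_config()
vllm_config.attention_config.backend = "FLASHINFER"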
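For reference, the user-facing precedence that EngineArgs.create_attention_config() establishes: the new --attention-backend CLI argument wins, the legacy VLLM_ATTENTION_BACKEND environment variable remains as a deprecated fallback, and automatic platform selection applies when neither is set. A hedged sketch of that resolution order as a hypothetical standalone helper (not vLLM API):

# Hypothetical helper mirroring the precedence in create_attention_config().
import os
from typing import Optional

def resolve_attention_backend(cli_backend: Optional[str]) -> Optional[str]:
    # 1. The new --attention-backend CLI argument takes precedence.
    if cli_backend is not None:
        return cli_backend
    # 2. The legacy VLLM_ATTENTION_BACKEND env var is a deprecated fallback
    #    (the selector later strips an old "_VLLM_V1" suffix with a warning).
    env_backend = os.environ.get("VLLM_ATTENTION_BACKEND")
    if env_backend is not None:
        return env_backend
    # 3. None means the platform chooses a backend automatically.
    return None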