[Attention][UX] Add AttentionConfig and change attention backend to CLI argument #26315
base: main
Changes from all commits: fceaa69, 4037da2, cc5d942, 08aa82b, 6933f2c, ce87bd4, 6a25027, e1d524d, 7e35756
New file adding `AttentionConfig` (diff hunk `@@ -0,0 +1,77 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from typing import Any, Optional

from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

from vllm.config.utils import config


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class AttentionConfig:
    """Configuration for attention mechanisms in vLLM."""

    backend: Optional[str] = None
    """Attention backend to use. If None, will be selected automatically.
    Example options: FLASH_ATTN, XFORMERS, FLASHINFER, etc."""

    use_triton_flash_attn: bool = True
    """Whether to use triton flash attention."""

    flash_attn_version: Optional[int] = None
    """Force vllm to use a specific flash-attention version (2 or 3).
    Only valid when using the flash-attention backend."""

    v1_use_prefill_decode_attention: bool = False
    """Use separate prefill and decode kernels for V1 attention instead of
    the unified triton kernel."""

    use_aiter_unified_attention: bool = False
    """Use AITER triton unified attention for V1 attention."""

    flash_attn_max_num_splits_for_cuda_graph: int = 32
    """Flash Attention max number splits for cuda graph decode."""

    use_cudnn_prefill: bool = False
    """Whether to use cudnn prefill."""

    use_trtllm_attention: Optional[bool] = None
    """If set to True/False, use or don't use the TRTLLM attention backend
    in flashinfer. If None, auto-detect the attention backend in flashinfer."""

    disable_flashinfer_prefill: bool = False
    """Whether to disable flashinfer prefill."""

    flashinfer_disable_q_quantization: bool = False
    """If set, when using fp8 kv, do not quantize Q to fp8."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = [
            self.backend,
            self.use_triton_flash_attn,
            self.flash_attn_version,
            self.v1_use_prefill_decode_attention,
            self.use_aiter_unified_attention,
            self.flash_attn_max_num_splits_for_cuda_graph,
            self.use_cudnn_prefill,
            self.use_trtllm_attention,
            self.disable_flashinfer_prefill,
            self.flashinfer_disable_q_quantization,
        ]
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str
```
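As a quick illustration (not part of the diff), the new config can be constructed and hashed on its own. The import path below assumes the module sits at `vllm.config.attention`, matching the relative import in the next file, and the field values are arbitrary examples:

```python
# Illustrative usage sketch, not code from this PR.
from vllm.config.attention import AttentionConfig  # assumed module path

default_cfg = AttentionConfig()
flash_cfg = AttentionConfig(backend="FLASH_ATTN", flash_attn_version=3)

# Any difference in a hashed field yields a different compute_hash(), so
# cached compilation artifacts are not reused across incompatible
# attention setups.
assert default_cfg.compute_hash() != flash_cfg.compute_hash()
print(flash_cfg.backend, flash_cfg.compute_hash()[:8])
```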
Changes to `VllmConfig`:

```diff
@@ -20,6 +20,7 @@
 from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
 
+from .attention import AttentionConfig
 from .cache import CacheConfig
 from .compilation import CompilationConfig, CompilationLevel, CUDAGraphMode
 from .device import DeviceConfig
@@ -68,6 +69,8 @@ class VllmConfig:
     """Device configuration."""
     load_config: LoadConfig = field(default_factory=LoadConfig)
     """Load configuration."""
+    attention_config: AttentionConfig = field(default_factory=AttentionConfig)
+    """Attention configuration."""
     lora_config: Optional[LoRAConfig] = None
     """LoRA configuration."""
     speculative_config: Optional[SpeculativeConfig] = None
@@ -153,6 +156,10 @@ def compute_hash(self) -> str:
             vllm_factors.append(self.load_config.compute_hash())
         else:
             vllm_factors.append("None")
+        if self.attention_config:
+            vllm_factors.append(self.attention_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
             # LoRA creates static buffers based on max_num_batched_tokens.
```

An inline review comment was left on lines +159 to +162 (the `attention_config` block in the last hunk), referencing `vllm_factors.append(self.attention_config.compute_hash())`.
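For context, `VllmConfig.compute_hash` follows a simple pattern: each optional sub-config contributes either its own `compute_hash()` or the literal string `"None"` to a flat factor list, which is then hashed once. Below is a small self-contained sketch of that pattern; the toy classes are illustrative stand-ins, not vLLM code:

```python
import hashlib
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class ToyAttentionConfig:
    # Stand-in for AttentionConfig; a single hashed factor for brevity.
    backend: Optional[str] = None

    def compute_hash(self) -> str:
        factors: list[Any] = [self.backend]
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()


@dataclass
class ToyVllmConfig:
    attention_config: Optional[ToyAttentionConfig] = field(
        default_factory=ToyAttentionConfig
    )

    def compute_hash(self) -> str:
        vllm_factors: list[Any] = []
        # Same shape as the diff above: sub-config hash, or "None" when unset.
        if self.attention_config:
            vllm_factors.append(self.attention_config.compute_hash())
        else:
            vllm_factors.append("None")
        return hashlib.md5(str(vllm_factors).encode(), usedforsecurity=False).hexdigest()


# Changing only the attention backend changes the top-level hash too.
a = ToyVllmConfig(ToyAttentionConfig(backend="FLASH_ATTN"))
b = ToyVllmConfig(ToyAttentionConfig(backend="FLASHINFER"))
assert a.compute_hash() != b.compute_hash()
```

The `"None"` placeholder keeps the factor list the same length whether or not a sub-config is set, so the presence or absence of a config is itself part of the hash.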
Review comment:

The manual listing of fields in `compute_hash` is fragile. If a new field is added to `AttentionConfig` in the future, a developer might forget to update this list, leading to incorrect cache hashes. This can cause subtle and hard-to-debug issues.

To make this more robust, you could programmatically collect the fields. This ensures that any new field is automatically included. Since all current fields seem to affect the computation graph, iterating over all fields is appropriate. This can be done without extra imports by using the `__dataclass_fields__` attribute.
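A possible shape for that suggestion, sketched on a trimmed-down stand-in class rather than the real `AttentionConfig` (whether every field truly belongs in the hash is for the PR author to decide):

```python
import hashlib
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ExampleConfig:
    # Illustrative fields only; the real AttentionConfig has more.
    backend: Optional[str] = None
    use_triton_flash_attn: bool = True

    def compute_hash(self) -> str:
        # __dataclass_fields__ lists every declared field in declaration
        # order, so a newly added field is hashed automatically instead of
        # relying on a manually maintained list.
        factors: list[Any] = [
            getattr(self, name) for name in self.__dataclass_fields__
        ]
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()


print(ExampleConfig(backend="FLASH_ATTN").compute_hash())
```

Pydantic dataclasses are built on the standard-library dataclass machinery, so `__dataclass_fields__` should also be available on `AttentionConfig` itself.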