
Commit 95f684f

address v0.11.0
Signed-off-by: wangxiyuan <[email protected]>
1 parent 509511f

File tree

1 file changed (+143, -71 lines)


vllm_ascend/patch/worker/patch_common/patch_attention_selector.py

Lines changed: 143 additions & 71 deletions
@@ -26,82 +26,154 @@
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import resolve_obj_by_qualname
 
+from vllm_ascend.utils import vllm_version_is
 
-def get_attn_backend(
-    head_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: Optional[str],
-    block_size: int,
-    is_attention_free: bool = False,
-    use_mla: bool = False,
-    use_sfa: bool = False,
-    has_sink: bool = False,
-) -> type[AttentionBackend]:
-    """Selects which attention backend to use and lazily imports it."""
-    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
-    # value to be returned from the cache if the value changes between calls.
-    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
-    # private function.
-    return _cached_get_attn_backend(
-        head_size=head_size,
-        dtype=dtype,
-        kv_cache_dtype=kv_cache_dtype,
-        block_size=block_size,
-        is_attention_free=is_attention_free,
-        use_v1=envs.VLLM_USE_V1,
-        use_mla=use_mla,
-        use_sfa=use_sfa,
-        has_sink=has_sink,
-    )
+if vllm_version_is("0.10.2"):
 
+    def get_attn_backend(
+        head_size: int,
+        dtype: torch.dtype,
+        kv_cache_dtype: Optional[str],
+        block_size: int,
+        is_attention_free: bool = False,
+        use_mla: bool = False,
+        use_sfa: bool = False,
+        has_sink: bool = False,
+    ) -> type[AttentionBackend]:
+        """Selects which attention backend to use and lazily imports it."""
+        # Accessing envs.* behind an @lru_cache decorator can cause the wrong
+        # value to be returned from the cache if the value changes between calls.
+        # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
+        # private function.
+        return _cached_get_attn_backend(
+            head_size=head_size,
+            dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            block_size=block_size,
+            is_attention_free=is_attention_free,
+            use_v1=envs.VLLM_USE_V1,
+            use_mla=use_mla,
+            use_sfa=use_sfa,
+            has_sink=has_sink,
+        )
 
-@cache
-def _cached_get_attn_backend(
-    head_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: Optional[str],
-    block_size: int,
-    is_attention_free: bool,
-    use_v1: bool = False,
-    use_mla: bool = False,
-    use_sfa: bool = False,
-    has_sink: bool = False,
-) -> type[AttentionBackend]:
-    # If there are no attention layers (e.g. we are running Mamba),
-    # use the placeholder NO_ATTENTION
-    if is_attention_free:
-        from vllm.attention.backends.placeholder_attn import \
-            PlaceholderAttentionBackend
-        return PlaceholderAttentionBackend
+    @cache
+    def _cached_get_attn_backend(
+        head_size: int,
+        dtype: torch.dtype,
+        kv_cache_dtype: Optional[str],
+        block_size: int,
+        is_attention_free: bool,
+        use_v1: bool = False,
+        use_mla: bool = False,
+        use_sfa: bool = False,
+        has_sink: bool = False,
+    ) -> type[AttentionBackend]:
+        # If there are no attention layers (e.g. we are running Mamba),
+        # use the placeholder NO_ATTENTION
+        if is_attention_free:
+            from vllm.attention.backends.placeholder_attn import \
+                PlaceholderAttentionBackend
+            return PlaceholderAttentionBackend
 
-    # Check whether a particular choice of backend was
-    # previously forced.
-    #
-    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
-    # ENVIRONMENT VARIABLE.
-    selected_backend = None
-    backend_by_global_setting: Optional[_Backend] = (
-        get_global_forced_attn_backend())
-    if backend_by_global_setting is not None:
-        selected_backend = backend_by_global_setting
-    else:
-        # Check the environment variable and override if specified
-        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            selected_backend = backend_name_to_enum(backend_by_env_var)
-            if selected_backend is None:
-                raise ValueError(
-                    f"Invalid attention backend: '{backend_by_env_var}'. "
-                    f"Valid backends are: {list(_Backend.__members__.keys())}")
+        # Check whether a particular choice of backend was
+        # previously forced.
+        #
+        # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
+        # ENVIRONMENT VARIABLE.
+        selected_backend = None
+        backend_by_global_setting: Optional[_Backend] = (
+            get_global_forced_attn_backend())
+        if backend_by_global_setting is not None:
+            selected_backend = backend_by_global_setting
+        else:
+            # Check the environment variable and override if specified
+            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+            if backend_by_env_var is not None:
+                selected_backend = backend_name_to_enum(backend_by_env_var)
+                if selected_backend is None:
+                    raise ValueError(
+                        f"Invalid attention backend: '{backend_by_env_var}'. "
+                        f"Valid backends are: {list(_Backend.__members__.keys())}"
+                    )
 
-    # get device-specific attn_backend
-    attention_cls = current_platform.get_attn_backend_cls(
-        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
-        use_mla, use_sfa, has_sink)
-    if not attention_cls:
-        raise ValueError(
-            f"Invalid attention backend for {current_platform.device_name}")
-    return resolve_obj_by_qualname(attention_cls)
+        # get device-specific attn_backend
+        attention_cls = current_platform.get_attn_backend_cls(
+            selected_backend, head_size, dtype, kv_cache_dtype, block_size,
+            use_v1, use_mla, use_sfa, has_sink)
+        if not attention_cls:
+            raise ValueError(
+                f"Invalid attention backend for {current_platform.device_name}"
+            )
+        return resolve_obj_by_qualname(attention_cls)
+else:
+
+    def get_attn_backend(
+        head_size: int,
+        dtype: torch.dtype,
+        kv_cache_dtype: Optional[str],
+        block_size: int,
+        use_mla: bool = False,
+        use_sfa: bool = False,
+        has_sink: bool = False,
+    ) -> type[AttentionBackend]:
+        """Selects which attention backend to use and lazily imports it."""
+        # Accessing envs.* behind an @lru_cache decorator can cause the wrong
+        # value to be returned from the cache if the value changes between calls.
+        # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
+        # private function.
+        return _cached_get_attn_backend(
+            head_size=head_size,
+            dtype=dtype,
+            kv_cache_dtype=kv_cache_dtype,
+            block_size=block_size,
+            use_v1=envs.VLLM_USE_V1,
+            use_mla=use_mla,
+            use_sfa=use_sfa,
+            has_sink=has_sink,
+        )
+
+    @cache
+    def _cached_get_attn_backend(
+        head_size: int,
+        dtype: torch.dtype,
+        kv_cache_dtype: Optional[str],
+        block_size: int,
+        use_v1: bool = False,
+        use_mla: bool = False,
+        use_sfa: bool = False,
+        has_sink: bool = False,
+    ) -> type[AttentionBackend]:
+        # Check whether a particular choice of backend was
+        # previously forced.
+        #
+        # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
+        # ENVIRONMENT VARIABLE.
+        selected_backend = None
+        backend_by_global_setting: Optional[_Backend] = (
+            get_global_forced_attn_backend())
+        if backend_by_global_setting is not None:
+            selected_backend = backend_by_global_setting
+        else:
+            # Check the environment variable and override if specified
+            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+            if backend_by_env_var is not None:
+                selected_backend = backend_name_to_enum(backend_by_env_var)
+                if selected_backend is None:
+                    raise ValueError(
+                        f"Invalid attention backend: '{backend_by_env_var}'. "
+                        f"Valid backends are: {list(_Backend.__members__.keys())}"
+                    )
+
+        # get device-specific attn_backend
+        attention_cls = current_platform.get_attn_backend_cls(
+            selected_backend, head_size, dtype, kv_cache_dtype, block_size,
+            use_v1, use_mla, use_sfa, has_sink)
+        if not attention_cls:
+            raise ValueError(
+                f"Invalid attention backend for {current_platform.device_name}"
+            )
+        return resolve_obj_by_qualname(attention_cls)
 
 
 vllm.attention.get_attn_backend = get_attn_backend
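
For context, the if vllm_version_is("0.10.2"): ... else: ... gate above selects between the old selector signature (which still accepts is_attention_free) and the v0.11.0 signature (which drops it). The helper vllm_ascend.utils.vllm_version_is is not part of this diff; the snippet below is only a minimal sketch of such a version check, assuming it compares the installed vllm release string exactly, and may differ from the real implementation.

# Hypothetical sketch only -- the real vllm_version_is lives in
# vllm_ascend/utils.py and is not shown in this commit.
from importlib.metadata import PackageNotFoundError, version


def vllm_version_is(target: str) -> bool:
    """Return True if the installed vllm package reports exactly `target`."""
    try:
        return version("vllm") == target
    except PackageNotFoundError:
        # vllm is not installed; treat the check as not matching.
        return False

With a gate like this, the same patch file can be imported against either vLLM release: the 0.10.2 branch keeps is_attention_free (and the Mamba placeholder-backend shortcut), while the else branch matches the trimmed v0.11.0 signature, before the module finally rebinds vllm.attention.get_attn_backend.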
