
Commit 007dd90

[gpt-oss] Enable gpt-oss on ampere (#22714)
Signed-off-by: Yongye Zhu <[email protected]>
1 parent b8a9d0e commit 007dd90

10 files changed (+26, -17 lines changed)


tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py

Lines changed: 3 additions & 2 deletions
@@ -25,5 +25,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         compilation_config.custom_ops = ["all"]
 
     def get_attn_backend_cls(self, backend_name, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1, use_mla):
-        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
+                             kv_cache_dtype, block_size, use_v1, use_mla,
+                             has_sink):
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501

vllm/attention/layer.py

Lines changed: 3 additions & 1 deletion
@@ -138,6 +138,7 @@ def __init__(
         self.head_size = head_size
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
+        self.has_sink = extra_impl_args.get("sinks") is not None
 
         quant_method = quant_config.get_quant_method(
             self, prefix=prefix) if quant_config else None
@@ -165,7 +166,8 @@ def __init__(
                 kv_cache_dtype,
                 block_size,
                 is_attention_free,
-                use_mla=use_mla)
+                use_mla=use_mla,
+                has_sink=self.has_sink)
         else:
             self.attn_backend = attn_backend
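
The flag is derived once per layer from the extra implementation kwargs and then forwarded into backend selection. A minimal standalone sketch of that idea (illustration only, not the vLLM Attention class; the kwargs dict and sink tensor here are hypothetical):

import torch

def derive_has_sink(extra_impl_args: dict) -> bool:
    # A layer "has sinks" when the caller supplied a `sinks` parameter,
    # e.g. gpt-oss-style per-head attention-sink logits.
    return extra_impl_args.get("sinks") is not None

# Hypothetical usage: layers built with sink parameters request a
# sink-aware backend; layers without them keep the default path.
assert derive_has_sink({"sinks": torch.zeros(64)}) is True
assert derive_has_sink({}) is False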

vllm/attention/selector.py

Lines changed: 4 additions & 1 deletion
@@ -144,6 +144,7 @@ def get_attn_backend(
     block_size: int,
     is_attention_free: bool = False,
     use_mla: bool = False,
+    has_sink: bool = False,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@@ -158,6 +159,7 @@ def get_attn_backend(
         is_attention_free=is_attention_free,
         use_v1=envs.VLLM_USE_V1,
         use_mla=use_mla,
+        has_sink=has_sink,
     )
 
 
@@ -170,6 +172,7 @@ def _cached_get_attn_backend(
     is_attention_free: bool,
     use_v1: bool = False,
     use_mla: bool = False,
+    has_sink: bool = False,
 ) -> type[AttentionBackend]:
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
@@ -201,7 +204,7 @@ def _cached_get_attn_backend(
     # get device-specific attn_backend
     attention_cls = current_platform.get_attn_backend_cls(
         selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
-        use_mla)
+        use_mla, has_sink)
     if not attention_cls:
         raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def from_config(cls, config):
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 90
+        return 80
 
     @classmethod
     def get_name(cls) -> QuantizationMethods:
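
Lowering `get_min_capability` from 90 to 80 is what admits Ampere GPUs to the mxfp4 path: vLLM expresses compute capability as major * 10 + minor, so SM 8.0 (e.g. A100) scores 80 and now clears the threshold. A rough sketch of that comparison (illustrative helper, not the actual vLLM gating code):

def meets_min_capability(major: int, minor: int, min_capability: int) -> bool:
    # Compute capability encoded as major * 10 + minor (A100 -> 80, H100 -> 90).
    return major * 10 + minor >= min_capability

assert meets_min_capability(8, 0, 80)      # Ampere passes the new threshold
assert not meets_min_capability(8, 0, 90)  # ...but was rejected at the old one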

vllm/platforms/cpu.py

Lines changed: 2 additions & 2 deletions
@@ -91,8 +91,8 @@ def get_device_name(cls, device_id: int = 0) -> str:
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool) -> str:
         if selected_backend and selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:

vllm/platforms/cuda.py

Lines changed: 5 additions & 2 deletions
@@ -222,8 +222,8 @@ def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1,
-                             use_mla) -> str:
+                             kv_cache_dtype, block_size, use_v1, use_mla,
+                             has_sink) -> str:
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
@@ -321,6 +321,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
 
         # FlashAttention is the default for SM 8.0+ GPUs
         if cls.has_device_capability(80):
+            if has_sink:
+                logger.info_once("Using Triton backend on V1 engine.")
+                return TRITON_ATTN_VLLM_V1
             if is_default_backend_supported := is_attn_backend_supported(
                     FLASH_ATTN_V1, head_size, dtype,
                     allow_import_error=False):
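
The cuda.py change is the core of the Ampere enablement: on SM 8.0+ devices, a layer that carries sinks is routed to the Triton attention backend before the usual FlashAttention probe runs. A condensed sketch of that branch order (standalone illustration with placeholder backend names, not the full selector):

def pick_cuda_backend(device_capability: int, has_sink: bool,
                      flash_attn_supported: bool) -> str:
    # Mirrors the branch order added in this commit for SM 8.0+ GPUs.
    if device_capability >= 80:
        if has_sink:
            # Sink-aware layers (e.g. gpt-oss) go straight to Triton.
            return "TRITON_ATTN_VLLM_V1"
        if flash_attn_supported:
            return "FLASH_ATTN_V1"
    # Placeholder for the remaining fallback chain on older or unsupported setups.
    return "FALLBACK_BACKEND"

print(pick_cuda_backend(80, has_sink=True, flash_attn_supported=True))   # Triton
print(pick_cuda_backend(90, has_sink=False, flash_attn_supported=True))  # FlashAttention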

vllm/platforms/interface.py

Lines changed: 2 additions & 2 deletions
@@ -196,8 +196,8 @@ def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool) -> str:
         """Get the attention backend class of a device."""
         return ""

vllm/platforms/rocm.py

Lines changed: 2 additions & 2 deletions
@@ -188,8 +188,8 @@ def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
-                             kv_cache_dtype, block_size, use_v1,
-                             use_mla) -> str:
+                             kv_cache_dtype, block_size, use_v1, use_mla,
+                             has_sink) -> str:
         if use_mla:
             from vllm.attention.backends.rocm_aiter_mla import (
                 is_aiter_mla_enabled)

vllm/platforms/tpu.py

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ class TpuPlatform(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink) -> str:
         if (selected_backend != _Backend.PALLAS
                 and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)

vllm/platforms/xpu.py

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ class XPUPlatform(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool,
-                             use_mla: bool) -> str:
+                             block_size: int, use_v1: bool, use_mla: bool,
+                             has_sink: bool) -> str:
         if selected_backend is not None and selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
         use_v1 = envs.VLLM_USE_V1
