
Commit c6b9287

Force TRTLLM attention for gpt-oss on SM100 (#22678)
Signed-off-by: mgoin <[email protected]>
1 parent b1361c7 commit c6b9287
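
In effect, layers with attention sinks (as used by gpt-oss) now always take the TRTLLM attention path in the FlashInfer backend on SM100, rather than depending on the VLLM_USE_TRTLLM_ATTENTION environment variable. Below is a condensed sketch of the resulting decision order, abstracted from the diffs that follow; the helper function and its arguments are illustrative, not part of the codebase, and the hardware/head-count checks that precede this in the real function are omitted.

    from typing import Optional

    def pick_trtllm(has_sinks: bool, env_override: Optional[bool]) -> bool:
        """Illustrative condensation of use_trtllm_attention() after this commit."""
        if has_sinks:
            return True           # sinks are only supported by the TRTLLM kernels
        if env_override is not None:
            return env_override   # explicit VLLM_USE_TRTLLM_ATTENTION setting
        return False              # otherwise fall through to the existing heuristics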

File tree: 4 files changed (+20 -9 lines changed)

vllm/model_executor/models/gpt_oss.py

Lines changed: 1 addition & 4 deletions
@@ -8,7 +8,6 @@
 from torch import nn
 from transformers import GptOssConfig
 
-from vllm import envs
 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -70,11 +69,9 @@ def __init__(
 
         tp_size = get_tensor_model_parallel_world_size()
 
-        attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
-                                else torch.bfloat16)
         self.sinks = torch.nn.Parameter(
             torch.empty(config.num_attention_heads // tp_size,
                         dtype=attention_sink_dtype,
+                        dtype=torch.bfloat16,
                         requires_grad=False))
 
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
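
With this change the model no longer needs to know at construction time which attention backend will run: the sink logits are always allocated in bfloat16, and the FlashInfer backend (third file below) upcasts them to float32 when its kernels require it. A small standalone sketch of that dtype handling, with made-up sizes:

    import torch

    # Illustrative size only; the real code uses config.num_attention_heads // tp_size.
    num_heads_per_rank = 8
    sinks = torch.nn.Parameter(
        torch.empty(num_heads_per_rank, dtype=torch.bfloat16),
        requires_grad=False)

    # The backend now upcasts rather than rejecting non-float32 sinks.
    sinks_f32 = sinks.data if sinks.dtype == torch.float32 else sinks.data.to(torch.float32)
    assert sinks_f32.dtype == torch.float32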

vllm/utils/flashinfer.py

Lines changed: 8 additions & 0 deletions
@@ -154,6 +154,7 @@ def use_trtllm_attention(
     num_qo_heads: Optional[int],
     num_kv_heads: Optional[int],
     attn_head_size: Optional[int],
+    has_sinks: bool = False,
 ) -> bool:
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
     if not (current_platform.is_device_capability(100)
@@ -165,6 +166,13 @@
             or num_qo_heads % num_kv_heads != 0):
         return False
 
+    # If sinks are being used, we must use TRTLLM attention as it's
+    # the only backend that supports them
+    if has_sinks:
+        logger.info_once(
+            "Using TRTLLM attention (required for attention sinks).")
+        return True
+
     env_value = envs.VLLM_USE_TRTLLM_ATTENTION
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
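
Because has_sinks defaults to False, existing call sites of use_trtllm_attention are unaffected; only the FlashInfer metadata builder passes the new flag. A hypothetical call site is sketched below (the literal values are illustrative, and on hardware other than SM100 the function still returns False early):

    from vllm.utils.flashinfer import use_trtllm_attention

    # Hypothetical call mirroring the metadata builder: the first three positional
    # arguments are the token count, max sequence length, and KV-cache dtype string.
    force_trtllm = use_trtllm_attention(
        8192, 4096, "auto",
        num_qo_heads=64, num_kv_heads=8, attn_head_size=64,
        has_sinks=True)   # sinks force the TRTLLM path regardless of the env var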

vllm/v1/attention/backends/flashinfer.py

Lines changed: 7 additions & 4 deletions
@@ -523,14 +523,17 @@ def build(self,
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size
 
+        # Check if any layer uses sinks (requires TRTLLM attention)
+        has_sinks = self.global_hyperparameters.has_sinks
+
         # currently prefill trtllm attention does not support fp8 kv cache
         prefill_use_trtllm = not cache_dtype.startswith("fp8") \
             and use_trtllm_attention(
                 num_prefill_tokens, max_seq_len, cache_dtype,
-                num_qo_heads, num_kv_heads, head_dim)
+                num_qo_heads, num_kv_heads, head_dim, has_sinks)
         decode_use_trtllm = use_trtllm_attention(
             num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim)
+            num_qo_heads, num_kv_heads, head_dim, has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -642,9 +645,9 @@ def __init__(
                 f"heads in the layer. Expected {num_heads}, but got "
                 f"{sinks.shape[0]}."
             )
+            # Cast sinks to float32 if needed (FlashInfer requirement)
             if sinks.dtype != torch.float32:
-                raise ValueError("Sinks must be of type float32, but got "
-                                 f"{sinks.dtype}.")
+                sinks = sinks.to(torch.float32)
             self.sinks = sinks
 
     def forward(
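
The builder computes has_sinks once and threads it into both the prefill and the decode decision, so a sinks model cannot end up with mismatched kernels; note that an fp8 KV cache still disables the TRTLLM prefill path. A hypothetical condensation of those two decisions (the helper and its boolean inputs are invented for illustration, not part of the codebase):

    def choose_trtllm_paths(cache_dtype: str, has_sinks: bool,
                            heuristics_ok_prefill: bool,
                            heuristics_ok_decode: bool) -> tuple[bool, bool]:
        """Invented helper condensing the two decisions made in build()."""
        # use_trtllm_attention() returns True as soon as has_sinks is set, so the
        # flag effectively overrides the per-shape heuristics...
        prefill = (not cache_dtype.startswith("fp8")
                   and (has_sinks or heuristics_ok_prefill))
        # ...while the decode path has no extra fp8 guard.
        decode = has_sinks or heuristics_ok_decode
        return prefill, decode

    # Example: a non-fp8 KV cache with sinks selects TRTLLM for both phases.
    print(choose_trtllm_paths("auto", True, False, False))  # (True, True)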

vllm/v1/attention/backends/utils.py

Lines changed: 4 additions & 1 deletion
@@ -285,6 +285,7 @@ class PerLayerParameters:
     window_left: int
     logits_soft_cap: Optional[float]
     sm_scale: float
+    has_sinks: bool = False
 
 
 def get_per_layer_parameters(
@@ -307,9 +308,11 @@
         window_left = window_size[0] if window_size is not None else -1
         logits_soft_cap = getattr(impl, "logits_soft_cap", None)
         sm_scale = impl.scale
+        has_sinks = getattr(impl, "sinks", None) is not None
 
         per_layer_params[key] = PerLayerParameters(window_left,
-                                                   logits_soft_cap, sm_scale)
+                                                   logits_soft_cap, sm_scale,
+                                                   has_sinks)
 
     return per_layer_params
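
Since the new field has a default, other backends that construct PerLayerParameters without it keep working; sink usage is detected simply by whether the attention implementation exposes a non-None sinks attribute. A standalone sketch of that pattern follows (DummyImpl is invented for illustration; the dataclass mirrors the patched one):

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class PerLayerParameters:
        window_left: int
        logits_soft_cap: Optional[float]
        sm_scale: float
        has_sinks: bool = False   # new field; the default keeps old call sites valid

    class DummyImpl:
        """Invented stand-in for an attention backend impl object."""
        scale = 0.125
        sinks = torch.zeros(8)    # a non-None `sinks` attribute marks a sink layer

    impl = DummyImpl()
    params = PerLayerParameters(
        window_left=-1,
        logits_soft_cap=getattr(impl, "logits_soft_cap", None),
        sm_scale=impl.scale,
        has_sinks=getattr(impl, "sinks", None) is not None)
    print(params.has_sinks)       # True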
