Commit 000ccec

[Bugfix gpt-oss] Fix float32 convert for flashinfer sink support (#23016)
Signed-off-by: mgoin <[email protected]>
1 parent 68373d3 commit 000ccec

2 files changed: +9 −3 lines changed

vllm/attention/layer.py

Lines changed: 9 additions & 0 deletions

@@ -308,6 +308,15 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
         if hasattr(self.impl, "process_weights_after_loading"):
             self.impl.process_weights_after_loading(act_dtype)
 
+        # FlashInfer requires attention sinks to be float32
+        if (self.backend == _Backend.FLASHINFER_VLLM_V1
+                and hasattr(self.impl, 'sinks')):
+            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
+            assert isinstance(self.impl, FlashInferImpl)
+            if (self.impl.sinks is not None
+                    and self.impl.sinks.dtype != torch.float32):
+                self.impl.sinks = self.impl.sinks.to(torch.float32)
+
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend

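The hook added above can be illustrated outside vLLM. Below is a minimal standalone sketch (toy class and function names, not vLLM's actual API) of the pattern this commit introduces: attention sinks that arrive in a lower-precision dtype are cast to float32 only after weight loading, which is what FlashInfer requires.

import torch


class ToyImpl:
    # Stand-in for FlashInferImpl: it only needs to hold an optional sinks tensor.
    def __init__(self, sinks: torch.Tensor | None = None):
        self.sinks = sinks


def process_weights_after_loading(impl: ToyImpl) -> None:
    # Mirrors the added hook: FlashInfer requires sinks in float32, so convert
    # them once, after all checkpoint weights have been loaded.
    if impl.sinks is not None and impl.sinks.dtype != torch.float32:
        impl.sinks = impl.sinks.to(torch.float32)


impl = ToyImpl(sinks=torch.zeros(64, dtype=torch.bfloat16))
process_weights_after_loading(impl)
assert impl.sinks.dtype == torch.float32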
vllm/v1/attention/backends/flashinfer.py

Lines changed: 0 additions & 3 deletions

@@ -642,9 +642,6 @@ def __init__(
                     f"heads in the layer. Expected {num_heads}, but got "
                     f"{sinks.shape[0]}."
                 )
-            # Cast sinks to float32 if needed (FlashInfer requirement)
-            if sinks.dtype != torch.float32:
-                sinks = sinks.to(torch.float32)
         self.sinks = sinks
 
     def forward(

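Taken together, the two hunks move the float32 cast from FlashInferImpl.__init__ into Attention.process_weights_after_loading. A plausible reading (the commit message does not spell it out) is that the constructor runs before checkpoint weights are loaded, so a cast performed there may not apply to the sink values that are ultimately loaded, while the post-load hook is guaranteed to see the final tensor.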
Comments (0)