Commit af8486d

[Hardware][Intel-Gaudi] Enable FusedSDPA support for Intel Gaudi (HPU)
1 parent: 4c3aac5

File tree

1 file changed (+11, -1 lines)


vllm/attention/backends/hpu_attn.py

Lines changed: 11 additions & 1 deletion
@@ -10,7 +10,8 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -137,9 +138,17 @@ def __init__(
 
         self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                               '0').lower() in ['1', 'true']
+        self.fused_scaled_dot_product_attention = None
         if self.prefill_usefusedsdpa:
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
+            try:
+                from habana_frameworks.torch.hpex.kernels import FusedSDPA
+                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
+                    FusedSDPA)
+            except ImportError:
+                logger().warning("Could not import HPU FusedSDPA kernel. "
+                                 "vLLM will use native implementation.")
 
         suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
@@ -227,6 +236,7 @@ def forward(
                 matmul_qk_op=self.matmul_qk,
                 softmax_op=self.softmax,
                 matmul_av_op=self.matmul_av,
+                fsdpa_op=self.fused_scaled_dot_product_attention,
             )
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
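
For context, below is a minimal usage sketch (not part of the commit) of how the new prefill path would be switched on from user code. The VLLM_PROMPT_USE_FUSEDSDPA variable, the accepted '1'/'true' values, and the fallback to the native implementation come from the diff above; the model name, prompt, and sampling settings are illustrative placeholders, and an HPU-enabled vLLM build with vllm_hpu_extension and habana_frameworks installed is assumed.

# Minimal sketch, not part of the commit. Assumes an HPU-enabled vLLM build
# with vllm_hpu_extension and habana_frameworks available; the model name
# and prompt are placeholders.
import os

# HPUAttentionImpl.__init__ (see diff) reads this variable and accepts '1' or
# 'true' (case-insensitive). If the FusedSDPA kernel cannot be imported, vLLM
# logs a warning and falls back to the native implementation.
os.environ["VLLM_PROMPT_USE_FUSEDSDPA"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # placeholder model name
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Intel Gaudi accelerators are"], params)
print(outputs[0].outputs[0].text)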
