@@ -10,7 +10,8 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
@@ -137,9 +138,17 @@ def __init__(
 
         self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                               '0').lower() in ['1', 'true']
+        self.fused_scaled_dot_product_attention = None
         if self.prefill_usefusedsdpa:
             assert alibi_slopes is None, \
                 'Prefill with FusedSDPA not supported with alibi slopes!'
+            try:
+                from habana_frameworks.torch.hpex.kernels import FusedSDPA
+                self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
+                    FusedSDPA)
+            except ImportError:
+                logger().warning("Could not import HPU FusedSDPA kernel. "
+                                 "vLLM will use native implementation.")
 
         suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
@@ -227,6 +236,7 @@ def forward(
                 matmul_qk_op=self.matmul_qk,
                 softmax_op=self.softmax,
                 matmul_av_op=self.matmul_av,
+                fsdpa_op=self.fused_scaled_dot_product_attention,
             )
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
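For context, the new prefill path is opt-in: __init__ above only wraps the Habana kernel in ModuleFusedSDPA when the VLLM_PROMPT_USE_FUSEDSDPA environment variable is truthy and the import succeeds; otherwise fsdpa_op stays None and prompt attention keeps using the native matmul/softmax path. A minimal sketch of enabling the flag (the variable name and accepted values come from the diff; the assumption is that it must be set before vLLM constructs the HPU attention backend, since the flag is read in __init__):

    import os

    # Opt in to the HPU FusedSDPA prefill kernel before the attention backend
    # is built; per the diff, '1' and 'true' are the accepted truthy values.
    os.environ["VLLM_PROMPT_USE_FUSEDSDPA"] = "1"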