@@ -154,25 +154,30 @@ def __init__(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {suppored_head_sizes}.")
 
-        self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+        self.use_naive_attn = False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = (os.environ.get(
             "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
-        if self.use_naive_attn:
-            # AMD Radeon 7900 series (gfx1100) currently does not support
-            # xFormers nor FlashAttention. As a temporary workaround, we use
-            # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention
-            logger.debug("Using naive attention in ROCmBackend")
-        elif self.use_triton_flash_attn:
+        if self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            from flash_attn import flash_attn_varlen_func  # noqa: F401
-            self.attn_func = flash_attn_varlen_func
-            logger.debug("Using CK FA in ROCmBackend")
+            # If not using Triton, navi3x does not use flash-attn either.
+            if torch.cuda.get_device_capability()[0] == 11:
+                self.use_naive_attn = True
+            else:
+                try:
+                    from flash_attn import flash_attn_varlen_func  # noqa: F401
+                    self.attn_func = flash_attn_varlen_func
+                    logger.debug("Using CK FA in ROCmBackend")
+                except ModuleNotFoundError:
+                    self.use_naive_attn = True
+
+        if self.use_naive_attn:
+            self.attn_func = _naive_attention
+            logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -247,13 +252,13 @@ def forward(
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_naive_attn or self.use_triton_flash_attn:
+                if self.use_triton_flash_attn or self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
                     if self.use_naive_attn:
-                        out = self.attn_fuc(
+                        out = self.attn_func(
                             query,
                             key,
                             value,
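As a quick illustration of the MQA interleave step in this hunk, `repeat_kv`'s docstring states it is equivalent to `torch.repeat_interleave` along the head dimension; the tensor shapes below are made up for the example.

```python
import torch

# Illustrative shapes only: 16 tokens, 2 KV heads, head_dim 64,
# and 4 query heads per KV head (num_heads // num_kv_heads).
tokens, n_kv_heads, head_dim = 16, 2, 64
num_queries_per_kv = 4

key = torch.randn(tokens, n_kv_heads, head_dim)
# Equivalent to repeat_kv(key, num_queries_per_kv) per its docstring.
repeated = torch.repeat_interleave(key, repeats=num_queries_per_kv, dim=1)
assert repeated.shape == (tokens, n_kv_heads * num_queries_per_kv, head_dim)
```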