Commit d3a6f21

[FEAT][ROCm] Enable running Flash Attention as ViT attn backend for Qwen-VL models on ROCm platform. (#22069)
Signed-off-by: tjtanaavllm <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Co-authored-by: tjtanaavllm <[email protected]>
1 parent 0edaf75 commit d3a6f21

File tree: 6 files changed, +64 −39 lines


vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 13 additions & 5 deletions
@@ -246,11 +246,15 @@ def __init__(
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
             )
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }

     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
@@ -297,10 +301,13 @@ def forward(
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.attn_backend == _Backend.ROCM_AITER_FA:
+                from aiter import flash_attn_varlen_func
+            else:
+                from flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -311,7 +318,7 @@ def forward(
                 cu_seqlens_k=cu_seqlens,
                 max_seqlen_q=max_seqlen,
                 max_seqlen_k=max_seqlen,
-                dropout_p=0,
+                dropout_p=0.0,
                 causal=False)

             context_layer = rearrange(output,
@@ -635,7 +642,8 @@ def compute_attn_mask_seqlen(
         cu_seqlens: torch.Tensor,
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
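
Both flash-attention providers are reached through the same call site: the new `is_flash_attn_backend` flag only changes which module `flash_attn_varlen_func` is imported from. The sketch below is illustrative rather than part of the commit; it assumes `q`, `k`, `v` are already flattened to `[total_tokens, num_heads, head_dim]` and that `cu_seqlens`/`max_seqlen` come from `compute_attn_mask_seqlen` as in the diff.

import torch

def vit_varlen_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         cu_seqlens: torch.Tensor, max_seqlen: int,
                         use_rocm_aiter: bool) -> torch.Tensor:
    # Only the import differs between the two providers; the varlen call
    # itself is identical, which is what allows the single
    # `is_flash_attn_backend` code path in the model.
    if use_rocm_aiter:
        from aiter import flash_attn_varlen_func       # ROCm AITER build
    else:
        from flash_attn import flash_attn_varlen_func  # upstream flash-attn

    # cu_seqlens holds cumulative token counts per image, e.g.
    # tensor([0, 16, 48], dtype=torch.int32) for two images of 16 and 32 patches.
    return flash_attn_varlen_func(q, k, v,
                                  cu_seqlens_q=cu_seqlens,
                                  cu_seqlens_k=cu_seqlens,
                                  max_seqlen_q=max_seqlen,
                                  max_seqlen_k=max_seqlen,
                                  dropout_p=0.0,
                                  causal=False)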

vllm/model_executor/models/qwen2_vl.py

Lines changed: 13 additions & 5 deletions
@@ -274,10 +274,14 @@ def __init__(
         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
         if self.attn_backend not in {
-                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now.")
+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+        }

     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
@@ -324,10 +328,13 @@ def forward(
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if self.is_flash_attn_backend:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.attn_backend == _Backend.ROCM_AITER_FA:
+                from aiter import flash_attn_varlen_func
+            else:
+                from flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -338,7 +345,7 @@ def forward(
                 cu_seqlens_k=cu_seqlens,
                 max_seqlen_q=max_seqlen,
                 max_seqlen_k=max_seqlen,
-                dropout_p=0,
+                dropout_p=0.0,
                 causal=False)

             context_layer = rearrange(output,
@@ -620,7 +627,8 @@ def compute_attn_mask_seqlen(
         self, cu_seqlens: torch.Tensor
     ) -> tuple[Optional[int], Optional[list[int]]]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend == _Backend.FLASH_ATTN:
+        if (self.attn_backend == _Backend.FLASH_ATTN
+                or self.attn_backend == _Backend.ROCM_AITER_FA):
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         elif self.attn_backend == _Backend.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
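
As in Qwen2.5-VL, `compute_attn_mask_seqlen` now treats `ROCM_AITER_FA` exactly like `FLASH_ATTN`: both need only the maximum sequence length, while the xFormers path keeps the full per-sequence lengths. A self-contained illustration with made-up patch counts:

import torch

# Hypothetical cumulative patch counts for three images (not from the commit).
cu_seqlens = torch.tensor([0, 16, 48, 112], dtype=torch.int32)

per_image = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([16, 32, 64])

# FLASH_ATTN / ROCM_AITER_FA path: a single scalar is enough.
max_seqlen = per_image.max().item()            # 64

# XFORMERS path: the full list of per-sequence lengths is kept instead.
seqlens = per_image.tolist()                   # [16, 32, 64]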

vllm/model_executor/models/vision.py

Lines changed: 7 additions & 29 deletions
@@ -7,9 +7,7 @@
 import torch
 from transformers import PretrainedConfig

-import vllm.envs as envs
-from vllm.attention.selector import (backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import get_env_variable_attn_backend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform

@@ -75,32 +73,12 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     Get the available attention backend for Vision Transformer.
     """
     # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
-    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-    if selected_backend is None:
-        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            selected_backend = backend_name_to_enum(backend_by_env_var)
-    if selected_backend is None:
-        if current_platform.is_cuda():
-            device_available = current_platform.has_device_capability(80)
-            if device_available and support_fa:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    selected_backend = _Backend.FLASH_ATTN
-                else:
-                    logger.warning_once(
-                        "Current `vllm-flash-attn` has a bug inside vision "
-                        "module, so we use xformers backend instead. You can "
-                        "run `pip install flash-attn` to use flash-attention "
-                        "backend.")
-                    selected_backend = _Backend.XFORMERS
-            else:
-                # For Volta and Turing GPUs, use xformers instead.
-                selected_backend = _Backend.XFORMERS
-        else:
-            # Default to torch SDPA for other non-GPU platforms.
-            selected_backend = _Backend.TORCH_SDPA
-    return selected_backend
+
+    selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
+    if selected_backend is not None:
+        return selected_backend
+
+    return current_platform.get_vit_attn_backend(support_fa)


 def resolve_visual_encoder_outputs(
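
`get_vit_attn_backend` is now a thin dispatcher: a backend forced through `get_env_variable_attn_backend` wins, otherwise the active platform decides. A hedged usage sketch; the expected outcomes in the comments are read off the `cuda.py` and `rocm.py` hooks in this commit and are not guaranteed for every environment:

from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend

backend = get_vit_attn_backend(support_fa=True)

# Typical outcomes after this commit (illustrative, not exhaustive):
#   CUDA, compute capability >= 8.0, flash-attn installed -> _Backend.FLASH_ATTN
#   CUDA otherwise                                         -> _Backend.XFORMERS
#   ROCm gfx9 with AITER (+ AITER MHA) enabled             -> _Backend.ROCM_AITER_FA
#   ROCm gfx9 otherwise                                    -> _Backend.FLASH_ATTN
#   other platforms                                        -> _Backend.TORCH_SDPA

# The Qwen-VL vision blocks then collapse both FA variants into one flag.
use_flash_attn = backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}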

vllm/platforms/cuda.py

Lines changed: 14 additions & 0 deletions
@@ -206,6 +206,20 @@ def get_current_memory_usage(cls,
         torch.cuda.reset_peak_memory_stats(device)
         return torch.cuda.max_memory_allocated(device)

+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        if cls.has_device_capability(80) and support_fa:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                return _Backend.FLASH_ATTN
+            logger.warning_once(
+                "Current `vllm-flash-attn` has a bug inside vision "
+                "module, so we use xformers backend instead. You can "
+                "run `pip install flash-attn` to use flash-attention "
+                "backend.")
+        # Fallback for Volta/Turing GPUs or FA not supported
+        return _Backend.XFORMERS
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
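
The CUDA hook preserves the pre-existing policy: the FA path requires compute capability 8.0 or newer plus an importable flash-attn build, otherwise it warns and falls back to xFormers. A small standalone check mirroring those two conditions (illustrative, not vLLM code):

import torch
from transformers.utils import is_flash_attn_2_available

major, minor = torch.cuda.get_device_capability()
ampere_or_newer = (major, minor) >= (8, 0)   # same bar as has_device_capability(80)
fa2_available = is_flash_attn_2_available()  # same helper the hook imports

print("ViT flash-attention eligible:", ampere_or_newer and fa2_available)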

vllm/platforms/interface.py

Lines changed: 5 additions & 0 deletions
@@ -46,6 +46,7 @@ class _Backend(enum.Enum):
     ROCM_FLASH = enum.auto()
     ROCM_AITER_MLA = enum.auto()  # Supported by V1
     ROCM_AITER_MLA_VLLM_V1 = enum.auto()
+    ROCM_AITER_FA = enum.auto()  # used for ViT attn backend
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
     FLASHINFER_VLLM_V1 = enum.auto()
@@ -186,6 +187,10 @@ def device_id_to_physical_device_id(cls, device_id: int):
         else:
             return device_id

+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
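
The base class supplies the conservative default, so any platform that does not override the hook keeps using torch SDPA for its vision encoders, matching the old behaviour for non-GPU platforms. A minimal illustration, assuming the import paths shown in this diff:

from vllm.platforms import _Backend
from vllm.platforms.interface import Platform

# The new enum member is only selected by the ROCm hook below.
print(_Backend.ROCM_AITER_FA)

# Platforms that do not override get_vit_attn_backend inherit the SDPA default.
assert Platform.get_vit_attn_backend(support_fa=True) is _Backend.TORCH_SDPA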

vllm/platforms/rocm.py

Lines changed: 12 additions & 0 deletions
@@ -173,6 +173,18 @@ class RocmPlatform(Platform):
         "quark", "ptpc_fp8"
     ]

+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        if support_fa:
+            if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
+                    and on_gfx9()):
+                # Note: AITER FA is only supported for Qwen-VL models.
+                # TODO: Add support for other VL models in their model class.
+                return _Backend.ROCM_AITER_FA
+            if on_gfx9():
+                return _Backend.FLASH_ATTN
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
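
On ROCm the choice is gated on gfx9 hardware and the two AITER flags, and everything else degrades to torch SDPA. The standalone function below restates that decision tree so the combinations are easy to check; `use_aiter`, `use_aiter_mha`, and `is_gfx9` are illustrative stand-ins for `envs.VLLM_ROCM_USE_AITER`, `envs.VLLM_ROCM_USE_AITER_MHA`, and `on_gfx9()`, and the local `Backend` enum is not vLLM's `_Backend`.

from enum import Enum, auto

class Backend(Enum):
    ROCM_AITER_FA = auto()  # AITER flash-attention (currently Qwen-VL ViT only)
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()

def rocm_vit_backend(support_fa: bool, use_aiter: bool,
                     use_aiter_mha: bool, is_gfx9: bool) -> Backend:
    # Mirrors RocmPlatform.get_vit_attn_backend from the diff above.
    if support_fa:
        if use_aiter and use_aiter_mha and is_gfx9:
            return Backend.ROCM_AITER_FA
        if is_gfx9:
            return Backend.FLASH_ATTN
    return Backend.TORCH_SDPA

assert rocm_vit_backend(True, True, True, True) is Backend.ROCM_AITER_FA
assert rocm_vit_backend(True, False, True, True) is Backend.FLASH_ATTN
assert rocm_vit_backend(True, True, True, False) is Backend.TORCH_SDPA
assert rocm_vit_backend(False, True, True, True) is Backend.TORCH_SDPA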
