196196import torch
197197
198198from vllm import _custom_ops as ops
199+ from vllm import envs
199200from vllm .attention .backends .abstract import (AttentionBackend , AttentionLayer ,
200201 AttentionMetadata ,
201202 AttentionMetadataBuilder ,
215216from vllm .utils import async_tensor_h2d , cdiv , make_tensor_with_pad , round_down
216217from vllm .vllm_flash_attn .fa_utils import get_flash_attn_version
217218
219+ if HAS_TRITON :
220+ from vllm .attention .ops .triton_flash_attention import triton_attention
221+ else :
222+ triton_attention = None
218223
219224try :
220225 from vllm .vllm_flash_attn import flash_attn_varlen_func
@@ -1039,6 +1044,7 @@ def __init__(
10391044 self .kv_b_proj = kv_b_proj
10401045 self .o_proj = o_proj
10411046
1047+ self .triton_fa_func = triton_attention
10421048 # Handle the differences between the flash_attn_varlen from flash_attn
10431049 # and the one from vllm_flash_attn. The former is used on RoCM and the
10441050 # latter has an additional parameter to control FA2 vs FA3
@@ -1064,6 +1070,14 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
10641070 maybe_padded_v = torch .nn .functional .pad (
10651071 v , [0 , q .shape [- 1 ] - v .shape [- 1 ]], value = 0 )
10661072
1073+ if is_hip and envs .VLLM_USE_TRITON_FLASH_ATTN \
1074+ and not return_softmax_lse :
1075+ attn_out = self .triton_fa_func (
1076+ q ,
1077+ k ,
1078+ maybe_padded_v ,
1079+ ** kwargs ,
1080+ )
10671081 if is_vllm_fa :
10681082 attn_out = self .flash_attn_varlen_func (
10691083 q = q ,
0 commit comments