Skip to content

Commit 175811e

Browse files
authored
[V1][Attention] Split triton_attn in triton-only and rocm specific backends (#24648)
Signed-off-by: Burkhard Ringlein <[email protected]>
1 parent c10101a commit 175811e

File tree

5 files changed

+483
-124
lines changed

5 files changed

+483
-124
lines changed

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,6 +1494,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
14941494
"FLEX_ATTENTION",
14951495
"TREE_ATTN",
14961496
"XFORMERS_VLLM_V1",
1497+
"ROCM_ATTN_VLLM_V1",
14971498
]
14981499
if (envs.is_set("VLLM_ATTENTION_BACKEND")
14991500
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):

vllm/platforms/interface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class _Backend(enum.Enum):
6767
FLEX_ATTENTION = enum.auto()
6868
TREE_ATTN = enum.auto()
6969
XFORMERS_VLLM_V1 = enum.auto()
70+
ROCM_ATTN_VLLM_V1 = enum.auto()
7071

7172

7273
class PlatformEnum(enum.Enum):

vllm/platforms/rocm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,17 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
231231
logger.info("Using Flash Attention backend on V1 engine.")
232232
return ("vllm.v1.attention.backends."
233233
"rocm_aiter_fa.AiterFlashAttentionBackend")
234+
elif (envs.VLLM_ROCM_USE_AITER and
235+
envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \
236+
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \
237+
selected_backend == _Backend.ROCM_ATTN_VLLM_V1:
238+
# rocm specific backend, with aiter and/or
239+
# triton prefix-prefill
240+
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
241+
return ("vllm.v1.attention.backends."
242+
"rocm_attn.RocmAttentionBackend")
234243
else:
244+
# default case, using triton unified attention
235245
logger.info("Using Triton Attention backend on V1 engine.")
236246
return ("vllm.v1.attention.backends."
237247
"triton_attn.TritonAttentionBackend")

0 commit comments

Comments (0)