Commit 469b3ff

[V1] port xformers backend to v1 (#21342)
Signed-off-by: Giancarlo Delfin <[email protected]>
1 parent ae87ddd

6 files changed, +438 -1 lines changed

tests/v1/attention/utils.py
Lines changed: 2 additions & 0 deletions

@@ -128,6 +128,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
         _Backend.TREE_ATTN:
         "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
+        _Backend.XFORMERS_VLLM_V1:
+        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
     }

     if backend_name not in backend_map:
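
The map above keys each `_Backend` member to a fully qualified class path. A minimal sketch, using only the standard library, of how such a dotted path can be resolved to the backend class (vLLM's own resolver may differ; the commented call assumes a build that ships the v1 xformers backend):

```python
import importlib


def resolve_backend_cls(qualname: str):
    """Resolve a dotted path like "pkg.module.ClassName" to the class."""
    module_name, _, class_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Hypothetical usage:
# backend_cls = resolve_backend_cls(
#     "vllm.v1.attention.backends.xformers.XFormersAttentionBackend")
```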

vllm/engine/arg_utils.py
Lines changed: 1 addition & 0 deletions

@@ -1469,6 +1469,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "TORCH_SDPA_VLLM_V1",
             "FLEX_ATTENTION",
             "TREE_ATTN",
+            "XFORMERS_VLLM_V1",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
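
The oracle only gates backend names: once `XFORMERS_VLLM_V1` is in `V1_BACKENDS`, setting `VLLM_ATTENTION_BACKEND` to that value no longer disqualifies the V1 engine. A hedged end-to-end sketch; the model name is only an example, and the import assumes a build that includes this commit:

```python
import os

# Select the backend before vLLM reads its environment configuration.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS_VLLM_V1"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # example model only
outputs = llm.generate("Hello, my name is")
```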

vllm/platforms/cuda.py
Lines changed: 4 additions & 0 deletions

@@ -271,6 +271,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         TRITON_ATTN_VLLM_V1 = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"  # noqa: E501
         FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
         TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
+        XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

         if selected_backend == _Backend.FLASHINFER:
             logger.info_once("Using FlashInfer backend on V1 engine.")
@@ -291,6 +292,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         elif selected_backend == _Backend.TREE_ATTN:
             logger.info_once("Using Tree Attention backend on V1 engine.")
             return TREE_ATTN_V1
+        elif selected_backend == _Backend.XFORMERS_VLLM_V1:
+            logger.info_once("Using XFormers backend on V1 engine.")
+            return XFORMERS_V1

         from vllm.attention.selector import is_attn_backend_supported
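
The new branch mirrors the existing ones: match the enum member, log once with the message shown in the diff, and return the dotted class path. A sketch only, condensing the elif chain into a lookup table to show the shape of the dispatch (not the code vLLM actually runs):

```python
# Sketch under the names used in the diff; trimmed to two entries.
_V1_BACKEND_PATHS = {
    "TREE_ATTN": "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
    "XFORMERS_VLLM_V1":
    "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
}


def v1_backend_path(name: str) -> str:
    """Return the dotted backend path for a backend name, or raise."""
    try:
        return _V1_BACKEND_PATHS[name]
    except KeyError:
        raise ValueError(f"no V1 backend registered for {name!r}") from None
```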

vllm/platforms/interface.py
Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ class _Backend(enum.Enum):
     NO_ATTENTION = enum.auto()
     FLEX_ATTENTION = enum.auto()
     TREE_ATTN = enum.auto()
+    XFORMERS_VLLM_V1 = enum.auto()


 class PlatformEnum(enum.Enum):
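
Because `_Backend` is a plain `enum.Enum`, the string carried by `VLLM_ATTENTION_BACKEND` maps onto the new member by name lookup. A self-contained sketch with a trimmed stand-in enum (the real class has many more members):

```python
import enum


class _Backend(enum.Enum):  # trimmed stand-in, not the real class
    FLEX_ATTENTION = enum.auto()
    TREE_ATTN = enum.auto()
    XFORMERS_VLLM_V1 = enum.auto()


# Name-based lookup: the mechanism that ties a string setting to a member.
assert _Backend["XFORMERS_VLLM_V1"] is _Backend.XFORMERS_VLLM_V1
```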

vllm/v1/attention/backends/tree_attn.py
Lines changed: 0 additions & 1 deletion

@@ -316,7 +316,6 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
-        use_irope: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
