@@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend:
    Construct the backend instance determined by the backend_name string
    argument.

-    "XFORMERS" -> construct xformers backend
-
-    TODO: other backends
-
    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
@@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend:

    * Backend instance
    '''
-    if backend_name == STR_XFORMERS_ATTN_VAL:
-        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
-        from vllm.attention.backends.xformers import XFormersBackend
-        return XFormersBackend()
-    elif backend_name == STR_FLASH_ATTN_VAL:
-        from vllm.attention.backends.flash_attn import FlashAttentionBackend
+    if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"):
+        from vllm.v1.attention.backends.xformers import (
+            XFormersAttentionBackend)
+        return XFormersAttentionBackend()
+    if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"):
+        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend()
+    if backend_name == "TRITON_ATTN_VLLM_V1":
+        from vllm.v1.attention.backends.triton_attn import (
+            TritonAttentionBackend)
+        return TritonAttentionBackend()
+    if backend_name == "FLEX_ATTENTION":
+        from vllm.v1.attention.backends.flex_attention import (
+            FlexAttentionBackend)
+        return FlexAttentionBackend()
+    if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"):
+        from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
+        return TorchSDPABackend()
+    if backend_name == "FLASHINFER":
+        from vllm.v1.attention.backends.flashinfer import FlashInferBackend
+        return FlashInferBackend()

    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")


+def make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
+    dtype: torch.dtype,
+    seq_lens: list[int],
+) -> list[Any]:
+    """Create ALiBi biases compatible with xFormers attention tests."""
+    from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
+
+    if alibi_slopes is None:
+        return [None for _ in seq_lens]
+
+    attn_biases: list[Any] = []
+    num_heads = alibi_slopes.shape[0]
+    assert num_heads >= num_kv_heads, (
+        "ALiBi slopes expect at least as many heads as KV heads")
+
+    for seq_len in seq_lens:
+        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+        bias = bias[None, :] - bias[:, None]
+
+        padded_len = (seq_len + 7) // 8 * 8
+        bias_tensor = torch.empty(
+            1,
+            num_heads,
+            seq_len,
+            padded_len,
+            device=alibi_slopes.device,
+            dtype=dtype,
+        )[:, :, :, :seq_len].copy_(bias)
+        bias_tensor.mul_(alibi_slopes[:, None, None])
+        attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
+
+    return attn_biases
+
+
def _make_metadata_tensors(
    seq_lens: Optional[list[int]],
    context_lens: Optional[list[int]],
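
As a rough illustration of how the widened dispatch above might be exercised, here is a hedged sketch of a parametrized test. The backend-name strings and the AssertionError fallback come from the diff; the import path, the test function names, and the assumption that every backend is importable on the test machine are illustrative only (several backends are hardware- or install-dependent).

import pytest

from tests.kernels.utils import make_backend  # path assumed; adjust to the actual test-utils module


@pytest.mark.parametrize("backend_name", [
    "XFORMERS_VLLM_V1",
    "FLASH_ATTN_VLLM_V1",
    "TRITON_ATTN_VLLM_V1",
    "FLEX_ATTENTION",
    "TORCH_SDPA_VLLM_V1",
    "FLASHINFER",
])
def test_make_backend_dispatch(backend_name: str) -> None:
    # Each recognized name should map to a concrete backend instance;
    # whether the import succeeds depends on the installed extras/hardware.
    backend = make_backend(backend_name)
    assert backend is not None


def test_make_backend_rejects_unknown_name() -> None:
    # Unknown names fall through to the helper's final AssertionError branch.
    with pytest.raises(AssertionError):
        make_backend("NOT_A_BACKEND")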
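
The relative-position arithmetic inside make_alibi_bias is easier to see in isolation. Below is a minimal, self-contained sketch (not part of the diff; the function name, head count, and float32 dtype are assumed) of the same idea: an ALiBi bias is a (j - i) relative-position matrix scaled per head by its slope. The helper in the diff additionally pads the last dimension to a multiple of 8 (an xFormers alignment requirement) and wraps the result in LowerTriangularMaskWithTensorBias so the upper triangle is masked causally.

import torch


def alibi_bias_reference(alibi_slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # positions 0..seq_len-1; rel[i, j] = j - i
    pos = torch.arange(seq_len, dtype=torch.float32)
    rel = pos[None, :] - pos[:, None]            # (seq_len, seq_len)
    # one slope per head -> (num_heads, seq_len, seq_len)
    return alibi_slopes[:, None, None] * rel


slopes = torch.tensor([1.0, 0.5, 0.25, 0.125])   # e.g. 4 heads
bias = alibi_bias_reference(slopes, seq_len=6)
assert bias.shape == (4, 6, 6)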