
Commit 6e20924

Authored by WoosukKwon, LiuXiaoxuanPKU, simon-mo, heheda12345, and hongxiayang
Add attention sink in attention backends (#22320)
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: LiuXiaoxuanPKU <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: Minseok Lee <[email protected]>
Co-authored-by: Yongye Zhu <[email protected]>
1 parent dd16bdc commit 6e20924

File tree

7 files changed, +176 -45 lines changed

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 27 additions & 6 deletions

@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -95,7 +96,17 @@ def kernel_paged_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([num_queries_per_kv_padded],
+                    float("-inf"),
+                    dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_head_idx,
+            mask=head_mask,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
     acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
                    dtype=tl.float32)
@@ -223,6 +234,8 @@ def chunked_prefill_paged_decode(
         alibi_slopes=None,
         sliding_window=None,
         sm_scale=None,
+        # Optional tensor for sinks
+        sinks=None,
 ):
 
     if sm_scale is None:
@@ -253,6 +266,7 @@ def chunked_prefill_paged_decode(
             sliding_window=sliding_window,
             sm_scale=sm_scale,
             skip_decode=True,
+            sinks=sinks,
         )
 
     block_size = value_cache.shape[3]
@@ -281,11 +295,17 @@ def chunked_prefill_paged_decode(
     num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
                                     16)
 
-    use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
-                                                 block_size,
-                                                 num_queries_per_kv,
-                                                 max_seq_len, sliding_window,
-                                                 kv_cache_dtype, alibi_slopes)
+    use_custom = use_rocm_custom_paged_attention(
+        query.dtype,
+        head_size,
+        block_size,
+        num_queries_per_kv,
+        max_seq_len,
+        sliding_window,
+        kv_cache_dtype,
+        alibi_slopes,
+        sinks,
+    )
     if use_custom:
         _PARTITION_SIZE_ROCM = 256
         max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +354,7 @@ def chunked_prefill_paged_decode(
         query_ptr=query,
         key_cache_ptr=key_cache,
         value_cache_ptr=value_cache,
+        sink_ptr=sinks,
         block_tables_ptr=block_table,
         seq_lens_ptr=seq_lens,
         alibi_slopes_ptr=alibi_slopes,
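
Note on the kernel change above: seeding the running max M with the per-head sink logit (while keeping L = 1.0) is what implements the attention sink. A minimal PyTorch sketch, not the Triton kernel and with illustrative shapes, showing that this initialization matches a softmax with one extra logit equal to the sink but no corresponding value row:

import torch

torch.manual_seed(0)
sink = 0.25                    # per-head sink logit (illustrative value)
scores = torch.randn(6)        # q.k logits for one query and one head
values = torch.randn(6, 4)     # matching value rows

# Reference: softmax over [sink, scores]; the sink position has no value row.
p = torch.softmax(torch.cat([torch.tensor([sink]), scores]), dim=0)
out_ref = p[1:] @ values

# Online-softmax form used by the kernels: start from (M, L) = (sink, 1).
M, L, acc = torch.tensor(sink), torch.tensor(1.0), torch.zeros(4)
for s, v in zip(scores, values):
    M_new = torch.maximum(M, s)
    alpha = torch.exp(M - M_new)       # rescale old denominator/accumulator
    p_i = torch.exp(s - M_new)
    L = L * alpha + p_i
    acc = acc * alpha + p_i * v
    M = M_new
assert torch.allclose(acc / L, out_ref)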

vllm/attention/ops/prefix_prefill.py

Lines changed: 15 additions & 3 deletions

@@ -38,6 +38,7 @@ def _fwd_kernel(Q,
                 V,
                 K_cache,
                 V_cache,
+                sink_ptr,
                 B_Loc,
                 sm_scale,
                 k_scale,
@@ -126,7 +127,15 @@ def _fwd_kernel(Q,
                 other=0.0)  # [M,D]
 
     # initialize pointer to m and l
-    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        m_i = tl.load(
+            sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
+            mask=(offs_m < cur_batch_query_len),
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]
 
@@ -732,7 +741,8 @@ def context_attention_fwd(q,
                           alibi_slopes=None,
                           sliding_window=None,
                           sm_scale=None,
-                          skip_decode=False):
+                          skip_decode=False,
+                          sinks=None):
 
     q_dtype_is_f32 = q.dtype is torch.float32
 
@@ -781,6 +791,7 @@ def context_attention_fwd(q,
         sliding_window = 0
 
     if alibi_slopes is not None:
+        assert sinks is None, "Sinks arg is not supported with alibi"
         # need to reduce num. blocks when using fp32
         # due to increased use of GPU shared memory
         # if q.dtype is torch.float32:
@@ -843,7 +854,7 @@ def context_attention_fwd(q,
     max_seq_len = 0 if max_seq_len is None else max_seq_len
     extra_kargs = {}
     if current_platform.is_rocm():
-        extra_kargs = {"kpack": 2, "waves_per_eu": 2}
+        extra_kargs = {"kpack": 1, "waves_per_eu": 2}
 
     grid = lambda META: (batch, head,
                          triton.cdiv(max_input_len, META["BLOCK_M"]))
@@ -853,6 +864,7 @@ def context_attention_fwd(q,
         v,
         k_cache,
         v_cache,
+        sinks,
         b_loc,
         sm_scale,
         k_scale,
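
Side note on the m_i load above: all BLOCK_M query rows in a tile belong to the same cur_head, so every row is seeded with the same per-head sink value. A tiny PyTorch analogue of that broadcast, with assumed shapes and illustrative names:

import torch

BLOCK_M, num_query_heads = 4, 8
sinks = torch.randn(num_query_heads)        # [num_query_heads], one logit per head
cur_head = 3
idx = torch.full((BLOCK_M,), cur_head)      # like tl.full([BLOCK_M], cur_head)
m_i = sinks[idx]                            # same sink repeated for every row
assert torch.equal(m_i, sinks[cur_head].expand(BLOCK_M))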

vllm/attention/ops/triton_unified_attention.py

Lines changed: 28 additions & 2 deletions

@@ -52,6 +52,7 @@ def kernel_unified_attention_2d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
         value_cache_ptr,  # [num_blks, blk_size, num_kv_heads, head_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -131,7 +132,15 @@ def kernel_unified_attention_2d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
 
@@ -292,6 +301,7 @@ def kernel_unified_attention_3d(
         query_ptr,  # [num_tokens, num_query_heads, head_size]
         key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
         value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
         block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
         seq_lens_ptr,  # [num_seqs]
         alibi_slopes_ptr,  # [num_query_heads]
@@ -383,7 +393,15 @@ def kernel_unified_attention_3d(
 
     block_table_offset = seq_idx * block_table_stride
 
-    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if sink_ptr is None or segm_idx != 0:
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    else:
+        M = tl.load(
+            sink_ptr + query_offset_1,
+            mask=query_mask_1,
+            other=float("-inf"),
+        ).to(dtype=tl.float32)
+
     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
 
@@ -627,6 +645,8 @@ def unified_attention(
     v_descale,
     alibi_slopes=None,
     qq_bias=None,
+    # Optional tensor for sinks
+    sinks=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -635,6 +655,10 @@ def unified_attention(
     assert q.element_size() >= 2 or block_size >= 32, \
         "Block size must be at least 32 for fp8"
 
+    if sinks is not None:
+        assert sinks.shape[0] == q.shape[1], \
+            "Sinks must be num_query_heads size"
+
     use_alibi_slopes = alibi_slopes is not None
     use_qq_bias = qq_bias is not None
 
@@ -669,6 +693,7 @@ def unified_attention(
         query_ptr=q,
         key_cache_ptr=k,
         value_cache_ptr=v,
+        sink_ptr=sinks,
         block_tables_ptr=block_table,
         seq_lens_ptr=seqused_k,
         alibi_slopes_ptr=alibi_slopes,
@@ -741,6 +766,7 @@ def unified_attention(
         query_ptr=q,
         key_cache_ptr=k,
         value_cache_ptr=v,
+        sink_ptr=sinks,
         block_tables_ptr=block_table,
         seq_lens_ptr=seqused_k,
         alibi_slopes_ptr=alibi_slopes,
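
Note on the segm_idx != 0 guard in the 3D kernel: each segment of the split-KV path keeps its own partial online-softmax state, and loading the sink only for segment 0 makes it count exactly once after the partial states are merged. A small PyTorch check of that bookkeeping, assuming the usual (M, L) merge rule rather than the exact reduction kernel:

import torch

torch.manual_seed(0)
sink = torch.tensor(0.5)
scores = torch.randn(16)                  # logits for one query/head
seg0, seg1 = scores[:8], scores[8:]       # two KV segments

def partial_state(logits, m0=float("-inf"), l0=0.0):
    m = torch.maximum(torch.tensor(m0), logits.max())
    l = l0 * torch.exp(torch.tensor(m0) - m) + torch.exp(logits - m).sum()
    return m, l

# Segment 0 starts from the sink (M=sink, L=1); segment 1 starts from -inf.
m0, l0 = partial_state(seg0, m0=float(sink), l0=1.0)
m1, l1 = partial_state(seg1)

# Merge the two partial states.
m = torch.maximum(m0, m1)
l = l0 * torch.exp(m0 - m) + l1 * torch.exp(m1 - m)

# Reference: one softmax denominator over [sink, all scores].
ref = torch.cat([sink[None], scores]).logsumexp(dim=0)
assert torch.allclose(m + l.log(), ref)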

vllm/envs.py

Lines changed: 16 additions & 3 deletions

@@ -17,6 +17,7 @@
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
+    VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -151,6 +152,8 @@
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
     VLLM_ENABLE_RESPONSES_API_STORE: bool = False
+    VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
+    VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
 
 
 def get_default_cache_root():
@@ -326,6 +329,12 @@ def get_vllm_port() -> Optional[int]:
     (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
      ("true", "1")),
 
+    # Use AITER triton unified attention for V1 attention
+    "VLLM_USE_AITER_UNIFIED_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
@@ -1022,9 +1031,13 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_CUDNN_PREFILL":
    lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
 
-    # If set to 1, use the TRTLLM Attention backend in flashinfer.
-    "VLLM_USE_TRTLLM_ATTENTION":
-    lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
+    # If set to 1, use the TRTLLM Context Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
+
+    # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
+    "VLLM_USE_TRTLLM_DECODE_ATTENTION":
+    lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),
 
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
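
Usage note on the new flags: each value is produced by the lambda shown above when the attribute is read, so the environment variable has to be set beforehand (typically exported in the shell before launching vLLM). A minimal sketch that mirrors that parsing; only the flag name comes from this diff:

import os

# Illustrative: normally exported in the shell rather than set from Python.
os.environ["VLLM_USE_TRTLLM_DECODE_ATTENTION"] = "1"

# Same parsing as the lambda registered in vllm/envs.py above.
enabled = bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0")))
assert enabled is True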

vllm/v1/attention/backends/flash_attn.py

Lines changed: 10 additions & 0 deletions

@@ -373,6 +373,7 @@ def __init__(
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -410,6 +411,14 @@ def __init__(
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")
 
+        self.sinks = sinks
+        if self.sinks is not None:
+            assert self.vllm_flash_attn_version == 3, (
+                "Sinks are only supported in FlashAttention 3")
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -534,6 +543,7 @@ def forward(
             k_descale=layer._k_scale.expand(descale_shape),
             v_descale=layer._v_scale.expand(descale_shape),
             num_splits=attn_metadata.max_num_splits,
+            s_aux=self.sinks,
         )
         return output
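
For context, a hypothetical sketch (not vLLM code) of the preconditions a caller would have to satisfy before handing a sinks tensor to this backend; it mirrors the two assertions added in __init__ above, and the tensor is later forwarded to FlashAttention 3 as s_aux:

import torch

num_heads = 64
vllm_flash_attn_version = 3        # assumed: FA3 is required for sinks

# One learned logit per query head (hypothetical parameter).
sinks = torch.nn.Parameter(torch.zeros(num_heads))

# Mirrors the checks added in FlashAttentionImpl.__init__.
assert vllm_flash_attn_version == 3, \
    "Sinks are only supported in FlashAttention 3"
assert sinks.shape[0] == num_heads, \
    "Sinks must have the same number of heads as the layer"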
