Commit 4a6b72c

[BugFix] Fix triton compile error in kernel_unified_attention_2/3d caused by attention sinks (#22368)
Signed-off-by: LucasWilkinson <[email protected]>
1 parent b4b9813 commit 4a6b72c

File tree

1 file changed: +15 −8 lines


vllm/attention/ops/triton_unified_attention.py

Lines changed: 15 additions & 8 deletions
@@ -75,6 +75,7 @@ def kernel_unified_attention_2d(
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
@@ -132,7 +133,7 @@ def kernel_unified_attention_2d(

     block_table_offset = seq_idx * block_table_stride

-    if sink_ptr is None:
+    if not USE_SINKS:
         M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
         M = tl.load(
@@ -322,6 +323,7 @@ def kernel_unified_attention_3d(
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
@@ -393,14 +395,17 @@ def kernel_unified_attention_3d(

     block_table_offset = seq_idx * block_table_stride

-    if sink_ptr is None or segm_idx != 0:
-        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if USE_SINKS:
+        if segm_idx == 0:
+            M = tl.load(
+                sink_ptr + query_offset_1,
+                mask=query_mask_1,
+                other=float("-inf"),
+            ).to(dtype=tl.float32)
+        else:
+            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
-        M = tl.load(
-            sink_ptr + query_offset_1,
-            mask=query_mask_1,
-            other=float("-inf"),
-        ).to(dtype=tl.float32)
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -716,6 +721,7 @@ def unified_attention(
             USE_ALIBI_SLOPES=use_alibi_slopes,
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
+            USE_SINKS=(sinks is not None),
             SLIDING_WINDOW=(1 + window_size[0]),
             stride_k_cache_0=k.stride(0),
             stride_k_cache_1=k.stride(1),
@@ -787,6 +793,7 @@ def unified_attention(
             USE_ALIBI_SLOPES=use_alibi_slopes,
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
+            USE_SINKS=(sinks is not None),
             SLIDING_WINDOW=(1 + window_size[0]),
             stride_k_cache_0=k.stride(0),
             stride_k_cache_1=k.stride(1),
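
Note on the pattern: the fix replaces runtime `sink_ptr is None` checks inside the Triton kernels with a `USE_SINKS: tl.constexpr` flag computed at launch time, so the unused branch is pruned during kernel specialization and the sinks pointer is never inspected when no sinks tensor is passed. Below is a minimal, hypothetical sketch of that same pattern, not code from this commit; `toy_kernel`, `bias_ptr`, `HAS_BIAS`, and `run` are illustrative names.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def toy_kernel(
        out_ptr,
        bias_ptr,                # may be passed as None at launch time
        HAS_BIAS: tl.constexpr,  # compile-time flag, analogous to USE_SINKS
        BLOCK: tl.constexpr,
    ):
        offs = tl.arange(0, BLOCK)
        # Branch on the constexpr flag rather than on `bias_ptr is None`, so
        # the dead branch is removed at compile time and bias_ptr is never
        # read when the caller did not supply a bias tensor.
        if HAS_BIAS:
            val = tl.load(bias_ptr + offs).to(tl.float32)
        else:
            val = tl.zeros([BLOCK], dtype=tl.float32)
        tl.store(out_ptr + offs, val)


    def run(bias=None, block=128):
        out = torch.empty(block, device="cuda", dtype=torch.float32)
        # The flag is derived at launch time, mirroring
        # USE_SINKS=(sinks is not None) in this commit.
        toy_kernel[(1,)](out, bias, HAS_BIAS=(bias is not None), BLOCK=block)
        return out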
