@@ -75,6 +75,7 @@ def kernel_unified_attention_2d(
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
@@ -132,7 +133,7 @@ def kernel_unified_attention_2d(

     block_table_offset = seq_idx * block_table_stride

-    if sink_ptr is None:
+    if not USE_SINKS:
         M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
         M = tl.load(
@@ -322,6 +323,7 @@ def kernel_unified_attention_3d(
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
+    USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
@@ -393,14 +395,17 @@ def kernel_unified_attention_3d(

     block_table_offset = seq_idx * block_table_stride

-    if sink_ptr is None or segm_idx != 0:
-        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if USE_SINKS:
+        if segm_idx == 0:
+            M = tl.load(
+                sink_ptr + query_offset_1,
+                mask=query_mask_1,
+                other=float("-inf"),
+            ).to(dtype=tl.float32)
+        else:
+            M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
     else:
-        M = tl.load(
-            sink_ptr + query_offset_1,
-            mask=query_mask_1,
-            other=float("-inf"),
-        ).to(dtype=tl.float32)
+        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)

     L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -716,6 +721,7 @@ def unified_attention(
         USE_ALIBI_SLOPES=use_alibi_slopes,
         USE_QQ_BIAS=use_qq_bias,
         USE_SOFTCAP=(softcap > 0),
+        USE_SINKS=(sinks is not None),
         SLIDING_WINDOW=(1 + window_size[0]),
         stride_k_cache_0=k.stride(0),
         stride_k_cache_1=k.stride(1),
@@ -787,6 +793,7 @@ def unified_attention(
         USE_ALIBI_SLOPES=use_alibi_slopes,
         USE_QQ_BIAS=use_qq_bias,
         USE_SOFTCAP=(softcap > 0),
+        USE_SINKS=(sinks is not None),
         SLIDING_WINDOW=(1 + window_size[0]),
         stride_k_cache_0=k.stride(0),
         stride_k_cache_1=k.stride(1),
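
The core of this change: when USE_SINKS is set, the running max M of the online softmax is seeded with the per-head sink logit (and, in the segmented 3D kernel, only for segment 0, so the sink is counted exactly once). Together with L starting at 1.0, this is equivalent to appending the sink as an extra softmax column that contributes no value. A minimal NumPy sketch of that equivalence (illustrative names, not the Triton kernel itself):

import numpy as np

def reference_weights(scores, sink):
    # Softmax over [scores, sink]; keep only the weights of the real keys.
    ext = np.concatenate([scores, [sink]])
    p = np.exp(ext - ext.max())
    return (p / p.sum())[:-1]

def online_weights(scores, sink, block=3):
    # Flash-style block accumulation with the sink folded into the initial
    # state: M starts at the sink logit and L at 1.0, as the kernel does
    # when USE_SINKS is set.
    M, L = float(sink), 1.0
    for start in range(0, len(scores), block):
        s = scores[start:start + block]
        M_new = max(M, float(s.max()))
        L = L * np.exp(M - M_new) + np.exp(s - M_new).sum()
        M = M_new
    return np.exp(scores - M) / L

scores = np.array([0.3, -1.2, 2.0, 0.7, -0.5])
assert np.allclose(reference_weights(scores, 1.5), online_weights(scores, 1.5))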