@@ -17,16 +17,18 @@ def attention_sinks_kernel(
     sliding_window_size,
     q_head_num: tl.constexpr,
     k_head_num: tl.constexpr,
+    block_group_size: tl.constexpr,
     D: tl.constexpr,
     PAGE_SIZE: tl.constexpr,
     MAX_BLOCKS: tl.constexpr,
-    sync_space,
 ):
-    i_s, i_qh = tl.program_id(0), tl.program_id(1)
-    i_kvh = i_qh // (q_head_num // k_head_num)
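+    # Each program now covers a group of block_group_size query heads that map
+    # to the same KV head, so axis 1 of the grid indexes head groups.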
+    i_s, i_gh = tl.program_id(0), tl.program_id(1)
+    i_kvh = i_gh * block_group_size // (q_head_num // k_head_num)
 
     kv_seq_len = tl.load(kv_seq_lens + i_s)
     page_num = tl.cdiv(kv_seq_len, PAGE_SIZE)
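+    # Clamp the page count so the loop never walks past this sequence's row of
+    # the block table.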
+    page_num = min(page_num, MAX_BLOCKS)
+
     start_page_num = 0
     start_kv_len = 0
     if sliding_window_size != -1 and kv_seq_len > sliding_window_size:
@@ -36,16 +38,15 @@ def attention_sinks_kernel(
     cur_page_start = i_s * MAX_BLOCKS
     offset_page = tl.arange(0, PAGE_SIZE)
     offset_d = tl.arange(0, D)
-    Br: tl.constexpr = 1
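+    # Br rows of the tile correspond to the query heads of this group, so the
+    # per-head sinks are loaded as a vector rather than a scalar.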
+    Br: tl.constexpr = block_group_size
 
-    sink = tl.load(sinks + i_qh)
+    sink = tl.load(sinks + i_gh * block_group_size + tl.arange(0, Br))
     history_max = tl.zeros([Br], dtype=tl.float32) + sink
     l = tl.zeros([Br], dtype=tl.float32)
     acc = tl.zeros([Br, D], dtype=tl.float32)
 
-    offset_q = i_qh * D + offset_d
-    offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num
-    q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32)
+    offset_seq = (i_s * q_head_num + i_gh * block_group_size + tl.arange(0, Br)) * D
+    q = tl.load(query + offset_seq[:, None] + offset_d[None, :]).to(tl.float32)
 
     for page_idx in range(start_page_num, page_num):
         block_idx = tl.load(block_tables + cur_page_start + page_idx)
@@ -75,17 +76,13 @@ def attention_sinks_kernel(
         l = l * re_scale + tl.sum(p_exp, 1)
         acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v)
 
-        # The purpose of this store is to insert synchronization within the loop.
-        # Do not remove this store until triton solves the synchronization problem,
-        # as doing so may lead to accuracy problem.
-        tl.store(sync_space + tl.arange(0, Br), new_e_max)
         history_max = new_e_max
 
     sink = tl.math.exp(sink - history_max)
     l = l + sink
     acc = acc / l[:, None]
     tl.store(
-        attn_out + offset_seq[:, None] + offset_q[None, :],
+        attn_out + offset_seq[:, None] + offset_d[None, :],
         acc.to(attn_out.type.element_ty),
     )
 
@@ -106,23 +103,17 @@ def attention_sinks_triton(
     D = query.shape[-1] // q_head_num
     PAGE_SIZE = k_cache.shape[1]
     v_head_dim = v_cache.shape[-1]
+
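+    # Group the query heads that share a KV head (at most 16 per group); each
+    # program of the kernel handles one group.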
+    group_block_size = min(q_head_num // k_head_num, 16)
+    group_block_num = q_head_num // group_block_size
+
     attn_output = torch.zeros(
         (S, q_head_num, v_head_dim),
         dtype=query.dtype,
         device=query.device,
     )
-    sync_space = torch.empty(
-        (PAGE_SIZE,),
-        dtype=torch.float32,
-        device=query.device,
-    )
-
-    if isinstance(context_lens, list):
-        context_lens = torch.tensor(context_lens, device=query.device)
-    else:
-        context_lens = context_lens.to(query.device)
 
-    grid = [S, q_head_num]
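+    # One program per (sequence, query-head group) instead of per (sequence, head).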
+    grid = [S, group_block_num]
     attention_sinks_kernel[grid](
         query,
         k_cache,
@@ -135,10 +126,10 @@ def attention_sinks_triton(
         sliding_window_size,
         q_head_num,
         k_head_num,
+        group_block_size,
         D,
         PAGE_SIZE,
         block_tables.stride(0),
-        sync_space,
     )
 
     return attn_output.reshape(-1, q_head_num * v_head_dim)
@@ -151,6 +142,7 @@ def attention_sinks_prefill_kernel(
     v_cache,
     sinks,
     attn_out,
+    cum_seq_lens,
     block_tables,
     kv_seq_lens,
     scale,
@@ -160,98 +152,98 @@ def attention_sinks_prefill_kernel(
     D: tl.constexpr,
     PAGE_SIZE: tl.constexpr,
     MAX_BLOCKS: tl.constexpr,
-    B: tl.constexpr,
-    BS: tl.constexpr,
-    sync_space,
 ):
-    i_ns, i_qh = tl.program_id(0), tl.program_id(1)
+    i_b, i_qh = tl.program_id(0), tl.program_id(1)
     i_kvh = i_qh // (q_head_num // k_head_num)
 
-    for i_bs in range(BS):
-        i_s = i_ns * BS + i_bs
-
-        i_pos = -1
-        kv_seq_len = i_s
-
-        for i in range(B):
-            tmp_seq_len = tl.load(kv_seq_lens + i)
-            if kv_seq_len >= tmp_seq_len and i_pos == -1:
-                kv_seq_len -= tmp_seq_len
-            elif i_pos == -1:
-                i_pos = i
-
-        if i_pos != -1:
-            kv_seq_len += 1
-
-        page_num = tl.cdiv(kv_seq_len, PAGE_SIZE)
-        start_page_num = 0
-        start_kv_len = 0
-        if sliding_window_size != -1 and kv_seq_len > sliding_window_size:
-            start_kv_len = (kv_seq_len - sliding_window_size).to(tl.int32)
-            start_page_num = start_kv_len // PAGE_SIZE
-
-        cur_page_start = i_pos * MAX_BLOCKS
-        offset_page = tl.arange(0, PAGE_SIZE)
-        offset_d = tl.arange(0, D)
-        Br: tl.constexpr = 1
-
-        sink = tl.load(sinks + i_qh)
-        history_max = tl.zeros([Br], dtype=tl.float32) + sink
-        l = tl.zeros([Br], dtype=tl.float32)
-        acc = tl.zeros([Br, D], dtype=tl.float32)
-
-        offset_q = i_qh * D + offset_d
-        offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num
-        q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32)
-
-        for page_idx in range(start_page_num, page_num):
-            block_idx = tl.load(block_tables + cur_page_start + page_idx)
-            mask_page = ((page_idx * PAGE_SIZE + offset_page) < kv_seq_len) & (
-                (page_idx * PAGE_SIZE + offset_page) >= start_kv_len
-            )
-
-            offset_k = (
-                block_idx * PAGE_SIZE * k_head_num * D
-                + offset_page[:, None] * k_head_num * D
-                + i_kvh * D
-                + offset_d[None, :]
-            )
-            k = tl.load(k_cache + offset_k, mask=mask_page[:, None]).to(tl.float32)
-            v = tl.load(v_cache + offset_k, mask=mask_page[:, None]).to(tl.float32)
-
-            k = tl.trans(k, (1, 0))
-            qk = tl.dot(q, k)
-            qk = qk * scale
-            qk = tl.where(mask_page[None, :], qk, float("-inf"))
-
-            new_e_max = tl.maximum(tl.max(qk, 1), history_max)
-            re_scale = tl.exp(history_max - new_e_max)
-            p_exp = tl.exp(qk - new_e_max[:, None])
-
-            # Online softmax update
-            l = l * re_scale + tl.sum(p_exp, 1)
-            acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v)
-
-            # The purpose of this store is to insert synchronization within the loop.
-            # Do not remove this store until triton solves the synchronization problem,
-            # as doing so may lead to accuracy problem.
-            tl.store(sync_space + tl.arange(0, Br), new_e_max)
-            history_max = new_e_max
-
-        sink = tl.math.exp(sink - history_max)
-        l = l + sink
-        acc = acc / l[:, None]
-        tl.store(
-            attn_out + offset_seq[:, None] + offset_q[None, :],
-            acc.to(attn_out.type.element_ty),
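+    # cum_seq_lens holds prefix sums of the per-request query lengths, so
+    # [q_start_offset, q_end_offset) is the query-token range of request i_b.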
+    q_end_offset = tl.load(cum_seq_lens + i_b)
+    q_start_offset = 0
+    q_start_offset = q_start_offset.to(q_end_offset.dtype)
+    if i_b > 0:
+        q_start_offset = tl.load(cum_seq_lens + i_b - 1)
+
+    Br: tl.constexpr = 16
+
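+    # Walk the request's query tokens in tiles of Br rows; kv_seq_len is the
+    # number of keys visible to the first row of the tile under the causal mask.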
+    for i_s in range(q_start_offset, q_end_offset, Br):
+        kv_seq_len = tl.load(kv_seq_lens + i_b) + i_s - q_end_offset + 1
+
+        page_num = tl.cdiv(kv_seq_len + Br, PAGE_SIZE)
+        page_num = min(page_num, MAX_BLOCKS)
+
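+        # Per-row causal limit (and, with a sliding window, per-row start) for
+        # the Br query tokens in this tile.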
+        kv_seq_len_block = kv_seq_len + tl.arange(0, Br)
+        start_kv_len_block = tl.zeros([Br], dtype=tl.int32)
+
+        start_page_num = 0
+        if sliding_window_size != -1:
+            start_kv_len = max((kv_seq_len - sliding_window_size).to(tl.int32), 0)
+            start_page_num = start_kv_len // PAGE_SIZE
+            start_kv_len_block = max(
+                (kv_seq_len_block - sliding_window_size).to(tl.int32), 0
+            )
+
+        cur_page_start = i_b * MAX_BLOCKS
+        offset_page = tl.arange(0, PAGE_SIZE)
+        offset_d = tl.arange(0, D)
+
+        sink = tl.load(sinks + i_qh)
+        history_max = tl.zeros([Br], dtype=tl.float32) + sink
+        l = tl.zeros([Br], dtype=tl.float32)
+        acc = tl.zeros([Br, D], dtype=tl.float32)
+
+        offset_q = i_qh * D + offset_d
+        offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num
+        mask_seq = (tl.arange(0, Br) + i_s) < q_end_offset
+        q = tl.load(
+            query + offset_seq[:, None] + offset_q[None, :], mask=mask_seq[:, None]
+        ).to(tl.float32)
+
+        for page_idx in range(start_page_num, page_num):
+            block_idx = tl.load(block_tables + cur_page_start + page_idx)
+            cur_offset_page = page_idx * PAGE_SIZE + offset_page
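+            # 2D mask over [Br, PAGE_SIZE]: each query row attends only to keys
+            # inside its own causal / sliding-window range.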
+            mask_page = (cur_offset_page[None, :] < kv_seq_len_block[:, None]) & (
+                cur_offset_page[None, :] >= start_kv_len_block[:, None]
+            )
+
+            offset_k = (
+                block_idx * PAGE_SIZE * k_head_num * D
+                + offset_page[:, None] * k_head_num * D
+                + i_kvh * D
+                + offset_d[None, :]
             )
+            k = tl.load(k_cache + offset_k).to(tl.float32)
+            v = tl.load(v_cache + offset_k).to(tl.float32)
+
+            k = tl.trans(k, (1, 0))
+            qk = tl.dot(q, k)
+            qk = qk * scale
+            qk = tl.where(mask_page, qk, float("-inf"))
+
+            new_e_max = tl.maximum(tl.max(qk, 1), history_max)
+            re_scale = tl.exp(history_max - new_e_max)
+            p_exp = tl.exp(qk - new_e_max[:, None])
+
+            # Online softmax update
+            l = l * re_scale + tl.sum(p_exp, 1)
+            acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v)
+
+            history_max = new_e_max
+
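+        # Fold the attention sink into the softmax denominator before normalizing.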
+        sink = tl.math.exp(sink - history_max)
+        l = l + sink
+        acc = acc / l[:, None]
+        tl.store(
+            attn_out + offset_seq[:, None] + offset_q[None, :],
+            acc.to(attn_out.type.element_ty),
+            mask=mask_seq[:, None],
+        )
 
 
 def attention_sinks_prefill_triton(
     query,
     k_cache,
     v_cache,
     sinks,
+    seq_lens,
     block_tables,
     context_lens,
     scale,
@@ -260,10 +252,6 @@ def attention_sinks_prefill_triton(
     k_head_num,
 ):
     S = query.shape[0]
-    kernel_num = get_device_properties()[0]
-    BS = triton.cdiv(S, kernel_num)
-    NS = triton.cdiv(S, BS)
-
     D = query.shape[-1] // q_head_num
     PAGE_SIZE = k_cache.shape[1]
     v_head_dim = v_cache.shape[-1]
@@ -272,25 +260,18 @@ def attention_sinks_prefill_triton(
         dtype=query.dtype,
         device=query.device,
     )
-    sync_space = torch.empty(
-        (PAGE_SIZE,),
-        dtype=torch.float32,
-        device=query.device,
-    )
 
-    if isinstance(context_lens, list):
-        context_lens = torch.tensor(context_lens, device=query.device)
-    else:
-        context_lens = context_lens.to(query.device)
-    B = context_lens.shape[0]
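+    # Prefix sums of the per-request query lengths; the kernel uses them to
+    # recover each request's query-token range.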
+    cum_seq_lens = torch.cumsum(seq_lens, dim=0)
+    B = seq_lens.shape[0]
 
-    grid = [NS, q_head_num]
+    grid = [B, q_head_num]
     attention_sinks_prefill_kernel[grid](
         query,
         k_cache,
         v_cache,
         sinks,
         attn_output,
+        cum_seq_lens,
         block_tables,
         context_lens,
         scale,
@@ -300,9 +281,6 @@ def attention_sinks_prefill_triton(
         D,
         PAGE_SIZE,
         block_tables.stride(0),
-        B,
-        BS,
-        sync_space,
     )
 
     return attn_output.reshape(-1, q_head_num * v_head_dim)