Commit 5402f33

optimize sinks attention (#260)
1 parent d756fa4 commit 5402f33

File tree

1 file changed (+102, -124 lines)
python/sgl_kernel_npu/sgl_kernel_npu/attention/sinks_attention.py

Lines changed: 102 additions & 124 deletions
@@ -17,16 +17,18 @@ def attention_sinks_kernel(
     sliding_window_size,
     q_head_num: tl.constexpr,
     k_head_num: tl.constexpr,
+    block_group_size: tl.constexpr,
     D: tl.constexpr,
     PAGE_SIZE: tl.constexpr,
     MAX_BLOCKS: tl.constexpr,
-    sync_space,
 ):
-    i_s, i_qh = tl.program_id(0), tl.program_id(1)
-    i_kvh = i_qh // (q_head_num // k_head_num)
+    i_s, i_gh = tl.program_id(0), tl.program_id(1)
+    i_kvh = i_gh * block_group_size // (q_head_num // k_head_num)

     kv_seq_len = tl.load(kv_seq_lens + i_s)
     page_num = tl.cdiv(kv_seq_len, PAGE_SIZE)
+    page_num = min(page_num, MAX_BLOCKS)
+
     start_page_num = 0
     start_kv_len = 0
     if sliding_window_size != -1 and kv_seq_len > sliding_window_size:
@@ -36,16 +38,15 @@ def attention_sinks_kernel(
     cur_page_start = i_s * MAX_BLOCKS
     offset_page = tl.arange(0, PAGE_SIZE)
     offset_d = tl.arange(0, D)
-    Br: tl.constexpr = 1
+    Br: tl.constexpr = block_group_size

-    sink = tl.load(sinks + i_qh)
+    sink = tl.load(sinks + i_gh * block_group_size + tl.arange(0, Br))
     history_max = tl.zeros([Br], dtype=tl.float32) + sink
     l = tl.zeros([Br], dtype=tl.float32)
     acc = tl.zeros([Br, D], dtype=tl.float32)

-    offset_q = i_qh * D + offset_d
-    offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num
-    q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32)
+    offset_seq = (i_s * q_head_num + i_gh * block_group_size + tl.arange(0, Br)) * D
+    q = tl.load(query + offset_seq[:, None] + offset_d[None, :]).to(tl.float32)

     for page_idx in range(start_page_num, page_num):
         block_idx = tl.load(block_tables + cur_page_start + page_idx)
@@ -75,17 +76,13 @@ def attention_sinks_kernel(
         l = l * re_scale + tl.sum(p_exp, 1)
         acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v)

-        # The purpose of this store is to insert synchronization within the loop.
-        # Do not remove this store until triton solves the synchronization problem,
-        # as doing so may lead to accuracy problem.
-        tl.store(sync_space + tl.arange(0, Br), new_e_max)
         history_max = new_e_max

     sink = tl.math.exp(sink - history_max)
     l = l + sink
     acc = acc / l[:, None]
     tl.store(
-        attn_out + offset_seq[:, None] + offset_q[None, :],
+        attn_out + offset_seq[:, None] + offset_d[None, :],
         acc.to(attn_out.type.element_ty),
     )

@@ -106,23 +103,17 @@ def attention_sinks_triton(
     D = query.shape[-1] // q_head_num
     PAGE_SIZE = k_cache.shape[1]
     v_head_dim = v_cache.shape[-1]
+
+    group_block_size = min(q_head_num // k_head_num, 16)
+    group_block_num = q_head_num // group_block_size
+
     attn_output = torch.zeros(
         (S, q_head_num, v_head_dim),
         dtype=query.dtype,
         device=query.device,
     )
-    sync_space = torch.empty(
-        (PAGE_SIZE,),
-        dtype=torch.float32,
-        device=query.device,
-    )
-
-    if isinstance(context_lens, list):
-        context_lens = torch.tensor(context_lens, device=query.device)
-    else:
-        context_lens = context_lens.to(query.device)

-    grid = [S, q_head_num]
+    grid = [S, group_block_num]
     attention_sinks_kernel[grid](
         query,
         k_cache,
@@ -135,10 +126,10 @@ def attention_sinks_triton(
         sliding_window_size,
         q_head_num,
         k_head_num,
+        group_block_size,
         D,
         PAGE_SIZE,
         block_tables.stride(0),
-        sync_space,
     )

     return attn_output.reshape(-1, q_head_num * v_head_dim)
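
Note on the decode-path change above: each program now handles a block of query heads that share one KV head instead of a single head. The host picks group_block_size = min(q_head_num // k_head_num, 16), launches a [S, group_block_num] grid, loads the sinks for the whole group as a vector, and the q·K product becomes a [Br, D] x [D, PAGE_SIZE] tl.dot, so each loaded K/V page is reused across the group; the sync_space workaround store is gone. Below is a minimal plain-Python sketch of just the grouping rule. The head counts are hypothetical, and it assumes q_head_num // k_head_num is a power of two so a group never straddles two KV heads (the usual GQA configuration).

# Hypothetical head counts, for illustration only.
q_head_num, k_head_num = 64, 8

# Same grouping rule as the new attention_sinks_triton host code.
group_block_size = min(q_head_num // k_head_num, 16)
group_block_num = q_head_num // group_block_size

for i_gh in range(group_block_num):
    # Query heads processed together by program (i_s, i_gh).
    q_heads = range(i_gh * group_block_size, (i_gh + 1) * group_block_size)
    # Mirrors `i_kvh = i_gh * block_group_size // (q_head_num // k_head_num)` in the kernel.
    i_kvh = i_gh * group_block_size // (q_head_num // k_head_num)
    # Under the power-of-two assumption every head in the group maps to this KV head.
    assert all(h // (q_head_num // k_head_num) == i_kvh for h in q_heads)
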
@@ -151,6 +142,7 @@ def attention_sinks_prefill_kernel(
     v_cache,
     sinks,
     attn_out,
+    cum_seq_lens,
     block_tables,
     kv_seq_lens,
     scale,
@@ -160,98 +152,98 @@ def attention_sinks_prefill_kernel(
     D: tl.constexpr,
     PAGE_SIZE: tl.constexpr,
     MAX_BLOCKS: tl.constexpr,
-    B: tl.constexpr,
-    BS: tl.constexpr,
-    sync_space,
 ):
-    i_ns, i_qh = tl.program_id(0), tl.program_id(1)
+    i_b, i_qh = tl.program_id(0), tl.program_id(1)
     i_kvh = i_qh // (q_head_num // k_head_num)

-    for i_bs in range(BS):
-        i_s = i_ns * BS + i_bs
-
-        i_pos = -1
-        kv_seq_len = i_s
-
-        for i in range(B):
-            tmp_seq_len = tl.load(kv_seq_lens + i)
-            if kv_seq_len >= tmp_seq_len and i_pos == -1:
-                kv_seq_len -= tmp_seq_len
-            elif i_pos == -1:
-                i_pos = i
-
-        if i_pos != -1:
-            kv_seq_len += 1
-
-        page_num = tl.cdiv(kv_seq_len, PAGE_SIZE)
-        start_page_num = 0
-        start_kv_len = 0
-        if sliding_window_size != -1 and kv_seq_len > sliding_window_size:
-            start_kv_len = (kv_seq_len - sliding_window_size).to(tl.int32)
-            start_page_num = start_kv_len // PAGE_SIZE
-
-        cur_page_start = i_pos * MAX_BLOCKS
-        offset_page = tl.arange(0, PAGE_SIZE)
-        offset_d = tl.arange(0, D)
-        Br: tl.constexpr = 1
-
-        sink = tl.load(sinks + i_qh)
-        history_max = tl.zeros([Br], dtype=tl.float32) + sink
-        l = tl.zeros([Br], dtype=tl.float32)
-        acc = tl.zeros([Br, D], dtype=tl.float32)
-
-        offset_q = i_qh * D + offset_d
-        offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num
-        q = tl.load(query + offset_seq[:, None] + offset_q[None, :]).to(tl.float32)
-
-        for page_idx in range(start_page_num, page_num):
-            block_idx = tl.load(block_tables + cur_page_start + page_idx)
-            mask_page = ((page_idx * PAGE_SIZE + offset_page) < kv_seq_len) & (
-                (page_idx * PAGE_SIZE + offset_page) >= start_kv_len
-            )
-
-            offset_k = (
-                block_idx * PAGE_SIZE * k_head_num * D
-                + offset_page[:, None] * k_head_num * D
-                + i_kvh * D
-                + offset_d[None, :]
-            )
-            k = tl.load(k_cache + offset_k, mask=mask_page[:, None]).to(tl.float32)
-            v = tl.load(v_cache + offset_k, mask=mask_page[:, None]).to(tl.float32)
-
-            k = tl.trans(k, (1, 0))
-            qk = tl.dot(q, k)
-            qk = qk * scale
-            qk = tl.where(mask_page[None, :], qk, float("-inf"))
-
-            new_e_max = tl.maximum(tl.max(qk, 1), history_max)
-            re_scale = tl.exp(history_max - new_e_max)
-            p_exp = tl.exp(qk - new_e_max[:, None])
-
-            # Online softmax update
-            l = l * re_scale + tl.sum(p_exp, 1)
-            acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v)
-
-            # The purpose of this store is to insert synchronization within the loop.
-            # Do not remove this store until triton solves the synchronization problem,
-            # as doing so may lead to accuracy problem.
-            tl.store(sync_space + tl.arange(0, Br), new_e_max)
-            history_max = new_e_max
-
-        sink = tl.math.exp(sink - history_max)
-        l = l + sink
-        acc = acc / l[:, None]
-        tl.store(
-            attn_out + offset_seq[:, None] + offset_q[None, :],
-            acc.to(attn_out.type.element_ty),
+    q_end_offset = tl.load(cum_seq_lens + i_b)
+    q_start_offset = 0
+    q_start_offset = q_start_offset.to(q_end_offset.dtype)
+    if i_b > 0:
+        q_start_offset = tl.load(cum_seq_lens + i_b - 1)
+
+    Br: tl.constexpr = 16
+
+    for i_s in range(q_start_offset, q_end_offset, Br):
+        kv_seq_len = tl.load(kv_seq_lens + i_b) + i_s - q_end_offset + 1
+
+        page_num = tl.cdiv(kv_seq_len + Br, PAGE_SIZE)
+        page_num = min(page_num, MAX_BLOCKS)
+
+        kv_seq_len_block = kv_seq_len + tl.arange(0, Br)
+        start_kv_len_block = tl.zeros([Br], dtype=tl.int32)
+
+        start_page_num = 0
+        if sliding_window_size != -1:
+            start_kv_len = max((kv_seq_len - sliding_window_size).to(tl.int32), 0)
+            start_page_num = start_kv_len // PAGE_SIZE
+            start_kv_len_block = max(
+                (kv_seq_len_block - sliding_window_size).to(tl.int32), 0
+            )
+
+        cur_page_start = i_b * MAX_BLOCKS
+        offset_page = tl.arange(0, PAGE_SIZE)
+        offset_d = tl.arange(0, D)
+
+        sink = tl.load(sinks + i_qh)
+        history_max = tl.zeros([Br], dtype=tl.float32) + sink
+        l = tl.zeros([Br], dtype=tl.float32)
+        acc = tl.zeros([Br, D], dtype=tl.float32)
+
+        offset_q = i_qh * D + offset_d
+        offset_seq = (tl.arange(0, Br) + i_s) * D * q_head_num
+        mask_seq = (tl.arange(0, Br) + i_s) < q_end_offset
+        q = tl.load(
+            query + offset_seq[:, None] + offset_q[None, :], mask=mask_seq[:, None]
+        ).to(tl.float32)
+
+        for page_idx in range(start_page_num, page_num):
+            block_idx = tl.load(block_tables + cur_page_start + page_idx)
+            cur_offset_page = page_idx * PAGE_SIZE + offset_page
+            mask_page = (cur_offset_page[None, :] < kv_seq_len_block[:, None]) & (
+                cur_offset_page[None, :] >= start_kv_len_block[:, None]
+            )
+
+            offset_k = (
+                block_idx * PAGE_SIZE * k_head_num * D
+                + offset_page[:, None] * k_head_num * D
+                + i_kvh * D
+                + offset_d[None, :]
             )
+            k = tl.load(k_cache + offset_k).to(tl.float32)
+            v = tl.load(v_cache + offset_k).to(tl.float32)
+
+            k = tl.trans(k, (1, 0))
+            qk = tl.dot(q, k)
+            qk = qk * scale
+            qk = tl.where(mask_page, qk, float("-inf"))
+
+            new_e_max = tl.maximum(tl.max(qk, 1), history_max)
+            re_scale = tl.exp(history_max - new_e_max)
+            p_exp = tl.exp(qk - new_e_max[:, None])
+
+            # Online softmax update
+            l = l * re_scale + tl.sum(p_exp, 1)
+            acc = acc * re_scale[:, None] + tl.dot(p_exp.to(v.dtype), v)
+
+            history_max = new_e_max
+
+        sink = tl.math.exp(sink - history_max)
+        l = l + sink
+        acc = acc / l[:, None]
+        tl.store(
+            attn_out + offset_seq[:, None] + offset_q[None, :],
+            acc.to(attn_out.type.element_ty),
+            mask=mask_seq[:, None],
+        )


 def attention_sinks_prefill_triton(
     query,
     k_cache,
     v_cache,
     sinks,
+    seq_lens,
     block_tables,
     context_lens,
     scale,
@@ -260,10 +252,6 @@ def attention_sinks_prefill_triton(
     k_head_num,
 ):
     S = query.shape[0]
-    kernel_num = get_device_properties()[0]
-    BS = triton.cdiv(S, kernel_num)
-    NS = triton.cdiv(S, BS)
-
     D = query.shape[-1] // q_head_num
     PAGE_SIZE = k_cache.shape[1]
     v_head_dim = v_cache.shape[-1]
@@ -272,25 +260,18 @@ def attention_sinks_prefill_triton(
         dtype=query.dtype,
         device=query.device,
     )
-    sync_space = torch.empty(
-        (PAGE_SIZE,),
-        dtype=torch.float32,
-        device=query.device,
-    )

-    if isinstance(context_lens, list):
-        context_lens = torch.tensor(context_lens, device=query.device)
-    else:
-        context_lens = context_lens.to(query.device)
-    B = context_lens.shape[0]
+    cum_seq_lens = torch.cumsum(seq_lens, dim=0)
+    B = seq_lens.shape[0]

-    grid = [NS, q_head_num]
+    grid = [B, q_head_num]
     attention_sinks_prefill_kernel[grid](
         query,
         k_cache,
         v_cache,
         sinks,
         attn_output,
+        cum_seq_lens,
         block_tables,
         context_lens,
         scale,
@@ -300,9 +281,6 @@ def attention_sinks_prefill_triton(
         D,
         PAGE_SIZE,
         block_tables.stride(0),
-        B,
-        BS,
-        sync_space,
     )

     return attn_output.reshape(-1, q_head_num * v_head_dim)
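
Note on the prefill-path change above: the old kernel processed one flattened token per iteration and rediscovered its request by linearly scanning kv_seq_lens inside the kernel (the removed for i in range(B) search), and it also relied on the sync_space store. The new version computes cum_seq_lens = torch.cumsum(seq_lens, dim=0) on the host, launches a [B, q_head_num] grid, and each program walks its request's query tokens in tiles of Br = 16 rows, masking rows past the request boundary with mask_seq and giving each row its own causal KV length via kv_seq_len_block. A small plain-Python illustration of that indexing follows; the seq_lens / kv_seq_lens values are made up and this is not part of the kernel.

import torch

seq_lens = torch.tensor([5, 3])       # new query tokens per request (hypothetical)
kv_seq_lens = torch.tensor([9, 3])    # total KV length per request (hypothetical)
cum_seq_lens = torch.cumsum(seq_lens, dim=0)  # what the new host code passes to the kernel

Br = 16  # query tile size used by the new kernel
for i_b in range(seq_lens.numel()):
    q_end_offset = int(cum_seq_lens[i_b])
    q_start_offset = int(cum_seq_lens[i_b - 1]) if i_b > 0 else 0
    for i_s in range(q_start_offset, q_end_offset, Br):
        # Causal KV length of the tile's first row, as in the kernel; row r may
        # attend to kv_seq_len + r keys, and rows with i_s + r >= q_end_offset
        # are masked out via mask_seq.
        kv_seq_len = int(kv_seq_lens[i_b]) + i_s - q_end_offset + 1
        valid_rows = min(Br, q_end_offset - i_s)
        print(f"request {i_b}: tile start {i_s}, first-row kv_len {kv_seq_len}, {valid_rows} valid rows")
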
