Skip to content

Commit 38bfc20

Browse files
committed
add test for custom op lightning_attention
Signed-off-by: ChenxiQ <chenxi.qian.cq@outlook.com>
1 parent 8f90f4d commit 38bfc20

File tree

2 files changed

+378
-0
lines changed

2 files changed

+378
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import gc
2+
import math
3+
import copy
4+
import torch
5+
import torch_npu
6+
7+
# enable vllm-ascend custom ops
8+
from vllm_ascend.utils import enable_custom_op
9+
enable_custom_op()
10+
11+
12+
def build_decay(head_num):
    """Return the per-head decay rates as a 1-D tensor of shape (head_num,).

    The first head's rate is 2^-(2^-(log2(head_num) - 3)); each subsequent
    head multiplies by that same base, giving a geometric sequence.
    """
    base = 2.0 ** (-(2.0 ** -(math.log2(head_num) - 3)))
    rates = [base * base ** idx for idx in range(head_num)]
    return torch.tensor(rates)
17+
18+
19+
def lightning_attention_decode(q, k, v, kv_decay, kv_cache, dtype):
    """Single-token lightning-attention decode step for one head.

    Updates the KV state as KV = kv_decay * kv_cache + outer(k, v) and
    produces the output o = q @ KV.

    Note: `dtype` is accepted for interface symmetry with the prefill
    reference but is not used inside this function.
    """
    updated_kv = kv_decay * kv_cache + torch.outer(k, v)
    attn_out = torch.matmul(q, updated_kv)
    return attn_out, updated_kv
25+
26+
27+
def reference_lightning_attention_decode(query, key, value, slope_rate, kv_history, slot_ids, dtype):
    """CPU reference for the single-step lightning-attention decode op.

    Args (shapes per the op contract documented by the original test):
        query, key, value: (batch, head, 1, d)
        slope_rate:        (head,) per-head decay slopes
        kv_history:        (kv_batch, head, d, d) KV caches
        slot_ids:          (batch,) cache slot index per request
        dtype:             output dtype

    Returns:
        output:    (batch, head * d) attention output in `dtype`
        kv_caches: updated KV caches, cast to `dtype`
    """
    batch_num, head_num, _, d = query.shape
    # compute in fp32 for accuracy, cast back to `dtype` at the end
    query = query.to(torch.float32)
    key = key.to(torch.float32)
    value = value.to(torch.float32)
    slope_rate = slope_rate.to(torch.float32)
    kv_caches = kv_history.clone().to(torch.float32)

    # output layout: per-head d-sized slices concatenated along the feature axis
    output = torch.zeros(batch_num, head_num * d, dtype=dtype)

    for b in range(batch_num):
        slot = slot_ids[b]
        for h in range(head_num):
            q = query[b, h, 0, :]
            k = key[b, h, 0, :]
            v = value[b, h, 0, :]
            decay = math.exp(-slope_rate[h])
            o, kv = lightning_attention_decode(q, k, v, decay, kv_caches[slot, h, :, :], dtype)
            output[b, h * d:(h + 1) * d] = o.to(dtype)
            kv_caches[slot, h, :, :] = kv

    return output, kv_caches.to(dtype)
57+
58+
59+
def execute_lightning_attention_decode_case(q_batch_size, kv_cache_batch, head_num, head_dim,
                                            dtype=torch.float16):
    """Run the NPU lightning-attention decode op and compare it with the CPU reference.

    Fix: removed the stray `self` first parameter — this is a module-level
    function and every call site in this file calls it without an instance.

    Args:
        q_batch_size: number of decode requests (query batch).
        kv_cache_batch: number of KV-cache slots (must be >= q_batch_size).
        head_num: attention head count.
        head_dim: per-head dimension.
        dtype: working dtype for inputs and outputs.
    """
    query_cpu = torch.randn(q_batch_size, head_num, 1, head_dim).to(dtype)
    key_cpu = torch.randn(q_batch_size, head_num, 1, head_dim).to(dtype)
    value_cpu = torch.randn(q_batch_size, head_num, 1, head_dim).to(dtype)
    slope_rate_cpu = build_decay(head_num).to(dtype)
    kv_history_cpu = torch.randn(kv_cache_batch, head_num, head_dim, head_dim).to(dtype)
    # map the last q_batch_size cache slots to the requests
    slot_ids_cpu = torch.arange(kv_cache_batch).to(torch.int32)[-q_batch_size:]

    query_npu = copy.deepcopy(query_cpu).npu()
    key_npu = copy.deepcopy(key_cpu).npu()
    value_npu = copy.deepcopy(value_cpu).npu()
    slope_rate_npu = copy.deepcopy(slope_rate_cpu).npu()
    kv_history_npu = copy.deepcopy(kv_history_cpu).npu()
    slot_ids_npu = copy.deepcopy(slot_ids_cpu).npu()

    # calculate on npu
    # NOTE(review): the op appears to update kv_history_npu in place — the
    # cache comparison below relies on this; confirm against the op's schema.
    attention_npu_out = torch.ops._C_ascend.npu_lightning_attention_decode(
        query_npu, key_npu, value_npu, kv_history_npu, slope_rate_npu, slot_ids_npu)

    # calculate on cpu
    attention_cpu_out, kv_cache_cpu_out = reference_lightning_attention_decode(
        query_cpu, key_cpu, value_cpu, slope_rate_cpu, kv_history_cpu, slot_ids_cpu, dtype)

    # compare result
    torch.testing.assert_close(attention_npu_out.cpu(),
                               attention_cpu_out,
                               atol=1e-9,
                               rtol=1e-6)
    torch.testing.assert_close(kv_history_npu.cpu(),
                               kv_cache_cpu_out,
                               atol=1e-9,
                               rtol=1e-6)
93+
94+
95+
@torch.inference_mode()
def test_lightning_attention_decode_same_batch():
    """Decode where the KV-cache batch equals the query batch.

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    q_batch_size = 256
    head_num = 8
    head_dim = 128
    execute_lightning_attention_decode_case(q_batch_size, q_batch_size, head_num, head_dim)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
104+
105+
@torch.inference_mode()
def test_lightning_attention_decode_different_batch():
    """Decode where the KV cache has more slots than there are queries.

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    q_batch_size = 1
    kv_cache_batch = 256
    head_num = 8
    head_dim = 128
    execute_lightning_attention_decode_case(q_batch_size, kv_cache_batch, head_num, head_dim)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
115+
116+
@torch.inference_mode()
def test_lightning_attention_decode_fp32():
    """Decode in float32 instead of the default float16.

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    q_batch_size = 100
    head_num = 16
    head_dim = 128
    execute_lightning_attention_decode_case(q_batch_size, q_batch_size, head_num, head_dim, torch.float32)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
import gc
2+
import math
3+
import copy
4+
import torch
5+
import torch_npu
6+
7+
# enable vllm-ascend custom ops
8+
from vllm_ascend.utils import enable_custom_op
9+
enable_custom_op()
10+
11+
12+
def build_decay(head_num):
    """Return per-head decay rates, shape (head_num,).

    Geometric sequence: rate_i = s^(i + 1) with s = 2^-(2^-(log2(head_num) - 3)).
    """
    s = 2.0 ** (-(2.0 ** -(math.log2(head_num) - 3)))
    return torch.tensor([s ** (i + 1) for i in range(head_num)])
17+
18+
19+
def lightning_attention_prefill(qt, kt, vt, kvsum, diag_decay, q_decay, block_decay, k_decay, dtype):
    """One block of the lightning-attention prefill recurrence for a single head.

    Implements:
        O_intra = [(Q_t K_t^T) * M] V_t                 (within-block, masked)
        O_inter = Λ Q_t (KV)                            (carry-over from past blocks)
        KVsum   = λ^B KVsum + (λ^B Λ^-1 K_t)^T V_t      (running KV state)
        O_t     = O_intra + O_inter

    The intermediate `.to(dtype)` casts are deliberate: they round to the
    working precision so the reference matches the device kernel.
    """
    # within-block attention scores, masked by the decayed causal mask M
    scores = torch.matmul(qt, torch.transpose(kt, 0, 1))
    masked_scores = torch.mul(scores, diag_decay).to(dtype)
    o_intra = torch.matmul(masked_scores.to(torch.float32), vt)

    # contribution of the accumulated KV state from earlier blocks
    o_inter = q_decay * torch.matmul(qt, kvsum.to(dtype).to(torch.float32))

    # fold this block's keys/values into the running KV state
    decayed_kt = (k_decay * kt).to(dtype)
    kt_vt = torch.matmul(torch.transpose(decayed_kt, 0, 1).to(torch.float32), vt)
    new_kvsum = torch.add(block_decay * kvsum, kt_vt)

    return torch.add(o_intra, o_inter), new_kvsum
39+
40+
41+
def reference_lightning_attention(q, k, v, ed, block_size, kv_history, seq_len):
    """CPU reference for the blocked lightning-attention prefill op.

    Args:
        q, k, v: (batch, head, n, d) full-sequence projections.
        ed: (head,) per-head decay slopes.
        block_size: tile length B of the blocked recurrence.
        kv_history: optional (batch, head, d, d) initial KV state; zeros if None.
        seq_len: optional list of per-batch actual lengths; defaults to n.

    Returns:
        [output, kvsums]: (batch, head, n, d) attention output and the final
        (batch, head, d, d) KV state, both cast to q's dtype.
    """
    dtype = q.dtype
    batch_num, head_num, n, d = q.shape
    if seq_len is None:
        seq_len = [n] * batch_num
    B = block_size
    # NOTE(review): assumes n % B == 0 — confirm callers pad to a block multiple
    T = n // B

    # get Q, K, V, decay
    # in_tensors[0]: Query without tiling (batch, head, n, d)
    # in_tensors[1]: Key without tiling (batch, head, n, d)
    # in_tensors[2]: Value without tiling (batch, head, n, d)
    # in_tensors[3]: Decay (head)
    query = q.reshape(batch_num, head_num, T, B, d).to(torch.float32)  # (batch, head, T, B, d)
    key = k.reshape(batch_num, head_num, T, B, d).to(torch.float32)  # (batch, head, T, B, d)
    value = v.reshape(batch_num, head_num, T, B, d).to(torch.float32)  # (batch, head, T, B, d)
    decay = ed.to(torch.float32)  # (head)

    # initialize O, KVsum
    output = torch.zeros(batch_num, head_num, T, B, d, dtype=dtype)  # (batch, head, T, B, d)
    if kv_history is None:
        kvsums = torch.zeros(batch_num, head_num, d, d, dtype=torch.float32)
    else:
        kvsums = kv_history.clone().to(torch.float32)  # (batch, head, d, d)

    for batchidx in range(batch_num):
        for headidx in range(head_num):
            kvsum = kvsums[batchidx, headidx, :, :]

            # diag_decay: M with shape (B, B)
            # q_decay: Λ with shape (B, 1)
            # block_decay: λ^B with shape (1)
            # k_decay: λ^B Λ^-1 with shape (B, 1)
            s = decay[headidx]
            i = torch.arange(B).view(B, 1)
            j = torch.arange(B)
            index = i - j
            # causal decayed mask: exp(-s * (i - j)) on/below the diagonal,
            # exp(-inf) = 0 strictly above it
            diag_decay = torch.exp(s * torch.where(index>=0, -index, float('-inf')))
            q_decay = torch.exp(-s * (j + 1)).reshape(B, 1)
            block_decay = math.exp(-s * B)
            k_decay = torch.exp(-s * (B - i - 1))

            block_count = (seq_len[batchidx] + B - 1) // B
            tail_block_size = seq_len[batchidx] % B
            for t in range(block_count):
                qt = query[batchidx, headidx, t, :, :]
                kt = key[batchidx, headidx, t, :, :]
                vt = value[batchidx, headidx, t, :, :]
                if tail_block_size != 0 and t + 1 == block_count:
                    # last, partially-filled block: re-derive the decays for
                    # the shorter effective length; padded rows get exponent 0
                    # (padded K/V are zero-filled by the caller, so their
                    # contribution vanishes regardless)
                    e = tail_block_size - i - 1
                    e[tail_block_size:] = 0
                    k_decay = torch.exp(-s * e)
                    block_decay = math.exp(-s * tail_block_size)
                ot, kvsum = lightning_attention_prefill(
                    qt, kt, vt, kvsum, diag_decay, q_decay, block_decay, k_decay, dtype)
                output[batchidx, headidx, t, :, :] = ot.to(dtype)

            kvsums[batchidx, headidx, :, :] = kvsum

    output = output.reshape(batch_num, head_num, n, d)  # (batch, head, n, d)
    kvsums = kvsums.to(dtype)
    return [output, kvsums]
103+
104+
105+
def execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size,
                                             has_kv_history=False, actual_seq_len=None, dtype=torch.float16,
                                             slope_rate=None):
    """Run the NPU lightning-attention prefill op and check it against the CPU reference.

    Fix: removed the stray `self` first parameter — this is a module-level
    function and every call site in this file calls it without an instance.

    Args:
        batch_size: number of sequences.
        head_num: attention head count.
        max_seq_len: padded sequence length (multiple of block_size in callers).
        head_dim: per-head dimension.
        block_size: tile length of the blocked recurrence.
        has_kv_history: whether to seed a random initial KV state.
        actual_seq_len: optional per-batch real lengths; padding is zeroed.
        dtype: working dtype.
        slope_rate: optional explicit per-head decay; defaults to build_decay.
    """
    base = 0.1  # keep magnitudes small to limit low-precision accumulation error
    query_cpu = base * torch.randn(batch_size, head_num, max_seq_len, head_dim).to(dtype)
    key_cpu = base * torch.randn(batch_size, head_num, max_seq_len, head_dim).to(dtype)
    value_cpu = base * torch.randn(batch_size, head_num, max_seq_len, head_dim).to(dtype)
    if actual_seq_len:
        # zero the padded tail so CPU reference and NPU op see identical inputs
        for b in range(batch_size):
            if actual_seq_len[b] < max_seq_len:
                query_cpu[b, :, actual_seq_len[b]:, :] = 0
                key_cpu[b, :, actual_seq_len[b]:, :] = 0
                value_cpu[b, :, actual_seq_len[b]:, :] = 0

    slope_rate_cpu = slope_rate
    if slope_rate_cpu is None:
        slope_rate_cpu = build_decay(head_num).to(dtype)

    query_npu = copy.deepcopy(query_cpu).npu()
    key_npu = copy.deepcopy(key_cpu).npu()
    value_npu = copy.deepcopy(value_cpu).npu()
    slope_rate_npu = copy.deepcopy(slope_rate_cpu).npu()
    kv_history_cpu = None
    kv_history_npu = None
    if has_kv_history:
        kv_history_cpu = base * torch.randn(batch_size, head_num, head_dim, head_dim).to(dtype)
        kv_history_npu = copy.deepcopy(kv_history_cpu).npu()

    # calculate on npu
    attention_npu_out, kv_cache_npu_out = torch.ops._C_ascend.npu_lightning_attention_prefill(
        query_npu, key_npu, value_npu, slope_rate_npu, block_size, kv_history_npu, actual_seq_len)

    # calculate on cpu
    attention_cpu_out, kv_cache_cpu_out = reference_lightning_attention(
        query_cpu, key_cpu, value_cpu, slope_rate_cpu, block_size, kv_history_cpu, actual_seq_len)

    if actual_seq_len:
        for b in range(batch_size):
            if actual_seq_len[b] < max_seq_len:
                # npu default value may not be 0
                attention_npu_out[b, :, actual_seq_len[b]:, :] = 0

    # compare result
    torch.testing.assert_close(attention_npu_out.cpu(),
                               attention_cpu_out,
                               atol=1e-9,
                               rtol=1e-6)
    torch.testing.assert_close(kv_cache_npu_out.cpu(),
                               kv_cache_cpu_out,
                               atol=1e-9,
                               rtol=1e-6)
157+
158+
159+
@torch.inference_mode()
def test_lightning_attention_prefill_pad():
    """Prefill with full-length (no padding) sequences.

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    batch_size = 1
    head_num = 4
    max_seq_len = 8192
    head_dim = 128
    block_size = 128
    execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
170+
171+
@torch.inference_mode()
def test_lightning_attention_prefill_unpad_1():
    """Prefill where the real length (5) is shorter than one block (16).

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    batch_size = 1
    head_num = 8
    max_seq_len = 16
    block_size = 16
    head_dim = 128
    actual_seq_len = [5]
    execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, False,
                                             actual_seq_len)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
184+
@torch.inference_mode()
def test_lightning_attention_prefill_unpad_2():
    """Prefill with random block-aligned sequence lengths across the batch.

    Fixes: removed the stray `self` parameter (module-level pytest test);
    added the `@torch.inference_mode()` decorator for consistency with every
    sibling test; used integer division so `np.random.randint` receives an
    int upper bound instead of a float.
    """
    batch_size = 4
    head_num = 8
    max_seq_len = 2048
    block_size = 128
    head_dim = 128
    # random block-aligned lengths in [block_size, max_seq_len]
    actual_seq_len = [np.random.randint(1, max_seq_len // block_size + 1) * block_size
                      for _ in range(batch_size)]
    execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size,
                                             False, actual_seq_len)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
197+
198+
@torch.inference_mode()
def test_lightning_attention_prefill_unpad_3():
    """Prefill with non-block-aligned per-batch lengths (351, 129, 384).

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    batch_size = 3
    head_num = 8
    max_seq_len = 384
    block_size = 128
    head_dim = 128
    actual_seq_len = [351, 129, 384]
    execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, False,
                                             actual_seq_len)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
211+
212+
@torch.inference_mode()
def test_lightning_attention_prefill_unpad_4():
    """Prefill with an explicit slope_rate and a short (5-token) sequence.

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    batch_size = 1
    head_num = 4
    max_seq_len = 256
    block_size = 256
    head_dim = 128
    actual_seq_len = [5]
    slope_rate = torch.tensor([0.9170, 0.8409, 0.7711, 0.7071], dtype=torch.float16)
    execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size, False,
                                             actual_seq_len, torch.float16, slope_rate)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
226+
227+
@torch.inference_mode()
def test_lightning_attention_prefill_with_kv_history():
    """Prefill seeded with a random initial KV state.

    Fixes: removed the stray `self` parameter (module-level pytest test);
    used integer division so `np.random.randint` receives an int upper bound
    instead of a float.
    """
    batch_size = 4
    head_num = 8
    max_seq_len = 1024
    head_dim = 128
    block_size = 128
    # random block-aligned lengths in [block_size, max_seq_len]
    actual_seq_len = [np.random.randint(1, max_seq_len // block_size + 1) * block_size
                      for _ in range(batch_size)]
    execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size,
                                             True, actual_seq_len)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
241+
242+
@torch.inference_mode()
def test_lightning_attention_prefill_fp32():
    """Prefill in float32 with KV history and a non-aligned length (130).

    Fix: removed the stray `self` parameter — this is a module-level pytest
    test, and a bare `self` would be treated as an unknown fixture.
    """
    batch_size = 1
    head_num = 16
    max_seq_len = 256
    head_dim = 128
    block_size = 128
    actual_seq_len = [130]
    execute_lightning_attention_prefill_case(batch_size, head_num, max_seq_len, head_dim, block_size,
                                             True, actual_seq_len, torch.float32)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()

0 commit comments

Comments
 (0)