
Commit c726cd8

Merge pull request #373 from cen121212/2-11-main
2 parents 6476b2d + 7be1e33 commit c726cd8


2 files changed: +260 -0 lines changed

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import torch
import triton
import triton.language as tl


@triton.jit
def fused_rope_qk_mqa_kernel(
    query_ptr,    # [T, Hq, D]
    key_ptr,      # [T, Hk, D]
    cos_sin_ptr,  # [T, D_ROPE]: cos in the first D_ROPE//2 columns, sin in the second half
    out_q_ptr,
    out_k_ptr,
    stride_qt, stride_qh, stride_qd,
    stride_kt, stride_kh, stride_kd,
    stride_ct, stride_cd,
    stride_oqt, stride_oqh, stride_oqd,
    stride_okt, stride_okh, stride_okd,
    Hk: tl.constexpr,
    D_HEAD: tl.constexpr,
    D_ROPE: tl.constexpr,
    IS_NEOX_STYLE: tl.constexpr,
):
    # One program per (token, query head).
    pid_t = tl.program_id(0)
    pid_h = tl.program_id(1)

    # MQA/GQA: several query heads share one key head, so the key head index is
    # broadcast; those programs write identical rotated K values (redundant but correct).
    kh = pid_h % Hk

    # -------- rotary indices (tl.arange needs a power-of-two extent, so D_ROPE // 2 must be one)
    d = tl.arange(0, D_ROPE // 2)
    if IS_NEOX_STYLE:
        # NeoX style: rotate the first half of the rotary dims against the second half.
        idx_even = d
        idx_odd = d + D_ROPE // 2
    else:
        # GPT-J style: rotate interleaved even/odd pairs.
        idx_even = d * 2
        idx_odd = d * 2 + 1

    # cos / sin for this token position
    cos = tl.load(cos_sin_ptr + pid_t * stride_ct + d * stride_cd)
    sin = tl.load(
        cos_sin_ptr + pid_t * stride_ct + (d + D_ROPE // 2) * stride_cd
    )

    # ================= Q =================
    q_base = query_ptr + pid_t * stride_qt + pid_h * stride_qh

    q1 = tl.load(q_base + idx_even * stride_qd)
    q2 = tl.load(q_base + idx_odd * stride_qd)

    q_out1 = (q1 * cos) - (q2 * sin)
    q_out2 = (q1 * sin) + (q2 * cos)

    oq_base = out_q_ptr + pid_t * stride_oqt + pid_h * stride_oqh
    tl.store(oq_base + idx_even * stride_oqd, q_out1)
    tl.store(oq_base + idx_odd * stride_oqd, q_out2)

    # ================= K =================
    k_base = key_ptr + pid_t * stride_kt + kh * stride_kh
    k1 = tl.load(k_base + idx_even * stride_kd)
    k2 = tl.load(k_base + idx_odd * stride_kd)

    k_out1 = (k1 * cos) - (k2 * sin)
    k_out2 = (k1 * sin) + (k2 * cos)

    ok_base = out_k_ptr + pid_t * stride_okt + kh * stride_okh
    tl.store(ok_base + idx_even * stride_okd, k_out1)
    tl.store(ok_base + idx_odd * stride_okd, k_out2)

    # ========== pass-through of the non-rotary tail (branch pruned at compile time) ==========
    # D_HEAD - D_ROPE must also be a power of two for tl.arange.
    if D_HEAD > D_ROPE:
        dp = tl.arange(0, D_HEAD - D_ROPE)
        tl.store(
            oq_base + (dp + D_ROPE) * stride_oqd,
            tl.load(q_base + (dp + D_ROPE) * stride_qd),
        )
        tl.store(
            ok_base + (dp + D_ROPE) * stride_okd,
            tl.load(k_base + (dp + D_ROPE) * stride_kd),
        )


def fused_rope_qk_mqa(
    query,    # [T, Hq, D]
    key,      # [T, Hk, D]
    cos_sin,  # [T, rotary_dim]: cos | sin halves, already gathered per token position
    rotary_dim,
    is_neox_style,
):
    T, Hq, D = query.shape
    _, Hk, _ = key.shape

    out_q = torch.empty_like(query)
    out_k = torch.empty_like(key)

    grid = (T, Hq)

    fused_rope_qk_mqa_kernel[grid](
        query,
        key,
        cos_sin,
        out_q,
        out_k,
        query.stride(0),
        query.stride(1),
        query.stride(2),
        key.stride(0),
        key.stride(1),
        key.stride(2),
        cos_sin.stride(0),
        cos_sin.stride(1),
        out_q.stride(0),
        out_q.stride(1),
        out_q.stride(2),
        out_k.stride(0),
        out_k.stride(1),
        out_k.stride(2),
        Hk=Hk,
        D_HEAD=D,
        D_ROPE=rotary_dim,
        IS_NEOX_STYLE=is_neox_style,
    )

    return out_q, out_k
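
A quick way to sanity-check the kernel above is to compare it against an eager PyTorch rotation on small random inputs. The sketch below is not part of the commit: the shapes, the base-10000 frequency schedule, and the helper name rope_reference are illustrative assumptions; it only relies on the [T, rotary_dim] cos|sin cache layout that the kernel's indexing implies, NeoX-style rotation, and fused_rope_qk_mqa being in scope. It needs a CUDA device.

import torch


def rope_reference(x, cos_sin, rotary_dim):
    # x: [T, H, D]; cos_sin: [T, rotary_dim] with cos | sin halves (NeoX style).
    cos, sin = cos_sin.chunk(2, dim=-1)          # each [T, rotary_dim // 2]
    cos = cos[:, None, :]                        # broadcast over heads
    sin = sin[:, None, :]
    x1 = x[..., : rotary_dim // 2]
    x2 = x[..., rotary_dim // 2 : rotary_dim]
    out = x.clone()                              # tail beyond rotary_dim passes through
    out[..., : rotary_dim // 2] = x1 * cos - x2 * sin
    out[..., rotary_dim // 2 : rotary_dim] = x1 * sin + x2 * cos
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    T, Hq, Hk, D, rotary_dim = 8, 16, 1, 192, 64  # illustrative MQA shapes
    q = torch.randn(T, Hq, D, device="cuda", dtype=torch.float32)
    k = torch.randn(T, Hk, D, device="cuda", dtype=torch.float32)

    # Build a [T, rotary_dim] cache: cos of the per-position angles in the
    # first half of each row, sin in the second half.
    inv_freq = 1.0 / (
        10000.0 ** (torch.arange(0, rotary_dim, 2, device="cuda").float() / rotary_dim)
    )
    pos = torch.arange(T, device="cuda").float()
    angles = torch.outer(pos, inv_freq)           # [T, rotary_dim // 2]
    cos_sin = torch.cat([angles.cos(), angles.sin()], dim=-1)

    out_q, out_k = fused_rope_qk_mqa(q, k, cos_sin, rotary_dim, is_neox_style=True)
    torch.testing.assert_close(out_q, rope_reference(q, cos_sin, rotary_dim), rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(out_k, rope_reference(k, cos_sin, rotary_dim), rtol=1e-4, atol=1e-4)
    print("fused_rope_qk_mqa matches the eager reference")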
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
import torch
import triton
import triton.language as tl


@triton.jit
def fused_split_qk_norm_kernel(
    fused_ptr,   # [B, total_hidden]
    q_lora_ptr,  # [B, q_lora_rank]
    k_nope_ptr,  # [B, kv_lora_rank]
    k_pe_ptr,    # [B, qk_rope_dim]
    q_rms_w_ptr,
    q_rms_b_ptr,
    k_rms_w_ptr,
    k_rms_b_ptr,
    total_hidden: tl.constexpr,
    q_lora_rank: tl.constexpr,
    kv_lora_rank: tl.constexpr,
    qk_rope_dim: tl.constexpr,
    eps: tl.constexpr,
    Q_HAS_BIAS: tl.constexpr,
    K_HAS_BIAS: tl.constexpr,
    BLOCK_Q: tl.constexpr,   # next power of two >= q_lora_rank (tl.arange needs a power-of-two extent)
    BLOCK_K: tl.constexpr,   # next power of two >= kv_lora_rank
    BLOCK_PE: tl.constexpr,  # next power of two >= qk_rope_dim
):
    # One program per row of the fused projection output (assumed contiguous).
    pid = tl.program_id(0)

    base = pid * total_hidden

    # =====================================================
    # Q LoRA slice (RMSNorm)
    # =====================================================
    q_offs = tl.arange(0, BLOCK_Q)
    q_mask = q_offs < q_lora_rank
    q = tl.load(
        fused_ptr + base + q_offs,
        mask=q_mask,
        other=0.0,
    ).to(tl.float32)

    # Padded lanes are zero, so the sum is exact; divide by the true rank.
    q_var = tl.sum(q * q, axis=0) / q_lora_rank
    q_rstd = tl.rsqrt(q_var + eps)

    qw = tl.load(q_rms_w_ptr + q_offs, mask=q_mask, other=0.0)
    q = q * q_rstd * qw

    if Q_HAS_BIAS:
        qb = tl.load(q_rms_b_ptr + q_offs, mask=q_mask, other=0.0)
        q += qb

    tl.store(q_lora_ptr + pid * q_lora_rank + q_offs, q, mask=q_mask)

    # =====================================================
    # K NOPE slice (RMSNorm)
    # =====================================================
    k_base = base + q_lora_rank
    k_offs = tl.arange(0, BLOCK_K)
    k_mask = k_offs < kv_lora_rank

    k = tl.load(
        fused_ptr + k_base + k_offs,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    k_var = tl.sum(k * k, axis=0) / kv_lora_rank
    k_rstd = tl.rsqrt(k_var + eps)

    kw = tl.load(k_rms_w_ptr + k_offs, mask=k_mask, other=0.0)
    k = k * k_rstd * kw

    if K_HAS_BIAS:
        kb = tl.load(k_rms_b_ptr + k_offs, mask=k_mask, other=0.0)
        k += kb

    tl.store(k_nope_ptr + pid * kv_lora_rank + k_offs, k, mask=k_mask)

    # =====================================================
    # K PE slice (no norm, direct copy)
    # =====================================================
    pe_offs = tl.arange(0, BLOCK_PE)
    pe_mask = pe_offs < qk_rope_dim
    pe_base = k_base + kv_lora_rank

    k_pe = tl.load(
        fused_ptr + pe_base + pe_offs,
        mask=pe_mask,
    )

    tl.store(
        k_pe_ptr + pid * qk_rope_dim + pe_offs,
        k_pe,
        mask=pe_mask,
    )


def fused_split_qk_norm(
    fused_qkv_a_proj_out,  # [B, q_lora_rank + kv_lora_rank + qk_rope_dim], contiguous
    q_a_layernorm,
    kv_a_layernorm,
    q_lora_rank,
    kv_lora_rank,
    qk_rope_dim,
    eps=1e-6,
):
    B, total_hidden = fused_qkv_a_proj_out.shape
    device = fused_qkv_a_proj_out.device
    dtype = fused_qkv_a_proj_out.dtype

    q_lora = torch.empty((B, q_lora_rank), device=device, dtype=dtype)
    k_nope = torch.empty((B, kv_lora_rank), device=device, dtype=dtype)
    k_pe = torch.empty((B, qk_rope_dim), device=device, dtype=dtype)

    # RMSNorm modules may not define a bias at all; fall back to None and let
    # the constexpr flags prune the bias loads out of the compiled kernel.
    q_bias = getattr(q_a_layernorm, "bias", None)
    k_bias = getattr(kv_a_layernorm, "bias", None)

    fused_split_qk_norm_kernel[(B,)](
        fused_qkv_a_proj_out,
        q_lora,
        k_nope,
        k_pe,
        q_a_layernorm.weight,
        q_bias,
        kv_a_layernorm.weight,
        k_bias,
        total_hidden,
        q_lora_rank,
        kv_lora_rank,
        qk_rope_dim,
        eps,
        Q_HAS_BIAS=q_bias is not None,
        K_HAS_BIAS=k_bias is not None,
        BLOCK_Q=triton.next_power_of_2(q_lora_rank),
        BLOCK_K=triton.next_power_of_2(kv_lora_rank),
        BLOCK_PE=triton.next_power_of_2(qk_rope_dim),
    )

    # Restore the original shapes expected by the caller (unsqueeze(1)).
    k_nope = k_nope.unsqueeze(1)
    k_pe = k_pe.unsqueeze(1)

    return q_lora, k_nope, k_pe
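
As with the RoPE kernel, a small eager reference makes the split-plus-RMSNorm fusion easy to verify. The sketch below is illustrative and not part of the commit: _RMSNorm is a minimal bias-free stand-in for the model's q_a_layernorm / kv_a_layernorm, and the DeepSeek-V2-like rank sizes are assumptions. It checks that splitting first and normalizing in eager PyTorch matches the fused kernel's outputs, including the unsqueeze(1) on the key slices, and assumes fused_split_qk_norm is in scope on a CUDA device.

import torch


class _RMSNorm(torch.nn.Module):
    """Minimal stand-in for the model's q_a_layernorm / kv_a_layernorm (no bias)."""

    def __init__(self, dim, dtype=torch.float16, device="cuda"):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype, device=device))

    def forward(self, x, eps=1e-6):
        x32 = x.float()
        rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x32 * rstd * self.weight.float()).to(x.dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    B = 4
    q_lora_rank, kv_lora_rank, qk_rope_dim = 1536, 512, 64  # DeepSeek-V2-like sizes, for illustration
    total = q_lora_rank + kv_lora_rank + qk_rope_dim

    fused = torch.randn(B, total, device="cuda", dtype=torch.float16)
    q_norm = _RMSNorm(q_lora_rank)
    kv_norm = _RMSNorm(kv_lora_rank)

    q_lora, k_nope, k_pe = fused_split_qk_norm(
        fused, q_norm, kv_norm, q_lora_rank, kv_lora_rank, qk_rope_dim
    )

    # Eager reference: split first, then RMS-normalize the first two slices.
    ref_q, ref_k, ref_pe = fused.split([q_lora_rank, kv_lora_rank, qk_rope_dim], dim=-1)
    torch.testing.assert_close(q_lora, q_norm(ref_q), rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(k_nope, kv_norm(ref_k).unsqueeze(1), rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(k_pe, ref_pe.unsqueeze(1))
    print("fused_split_qk_norm matches the eager split + RMSNorm reference")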
