Commit b148bb4

Fix Gemma3 head norm (RMS)

1 parent d471d44

11 files changed: +129 −59 lines

exllamav2/architecture.py

Lines changed: 2 additions & 0 deletions
@@ -209,6 +209,7 @@ class Params:
 
     # Layer norm type
     norm = "rmsnorm"
+    headnorm = "layernorm"
 
     # RoPE style
     rope_style = RopeStyle.NEOX

@@ -520,6 +521,7 @@ class Params:
         self.lm.default_sliding_window_pattern = 6
         self.lm.default_rope_theta = 1e6
         self.lm.pos_id_index = 1
+        self.lm.headnorm = "rmsnorm"
 
         self.vt_prefix = "vision_tower.vision_model."
         self.vt.keys.update({
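
Note: the two hunks above introduce an architecture-level switch, headnorm, which defaults to "layernorm" and is overridden to "rmsnorm" for Gemma3. A minimal sketch of how such a selector flows downstream (illustrative Python only; the class and attribute names are simplified stand-ins, not the actual exllamav2 classes):

# Illustrative sketch only; simplified stand-ins, not the actual exllamav2 classes.

class LMParams:
    norm = "rmsnorm"          # existing layer-norm selector
    headnorm = "layernorm"    # new: norm type for the per-head Q/K norm

class Gemma3LMParams(LMParams):
    headnorm = "rmsnorm"      # Gemma3 override, as in the second hunk above

def head_norm_is_rms(params: LMParams) -> bool:
    # The attention layer only needs a boolean, mirroring the attn.py change below
    return params.headnorm == "rmsnorm"

assert head_norm_is_rms(Gemma3LMParams())
assert not head_norm_is_rms(LMParams())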

exllamav2/attn.py

Lines changed: 1 addition & 0 deletions
@@ -301,6 +301,7 @@ def load(self, device_context: bool = True):
                 norm_weight,
                 norm_bias,
                 is_rms,
+                self.archparams.headnorm == "rmsnorm",
                 eps,
                 self.q_proj.q_handle,
                 self.k_proj.q_handle,

exllamav2/exllamav2_ext/cuda/head_norm.cu

Lines changed: 89 additions & 43 deletions
@@ -2,7 +2,7 @@
 #include "util.cuh"
 #include "compat.cuh"
 
-#define MAX_HEAD_DIM 128
+#define MAX_HEAD_DIM 256
 #define WARP_SIZE 32
 #define MAX_WARPS (MAX_HEAD_DIM / WARP_SIZE)
 

@@ -16,7 +16,8 @@ __global__ void head_norm_kernel
     const float r_dim,
     const int rows,
     const int num_heads,
-    const int head_dim
+    const int head_dim,
+    const bool rms
 )
 {
     int warp_id = threadIdx.x / WARP_SIZE;

@@ -37,62 +38,106 @@ __global__ void head_norm_kernel
     float itemf[2];
     float sum = 0.0f;
 
-    half2 h01 = ((half2*)x_ptr)[t];
-    float f0 = __half2float(__low2half(h01));
-    float f1 = __half2float(__high2half(h01));
-    f0 = fmaxf(-65504.0f, fminf(f0, 65504.0f));
-    f1 = fmaxf(-65504.0f, fminf(f1, 65504.0f));
-    itemf[0] = f0;
-    itemf[1] = f1;
-    sum += f0;
-    sum += f1;
+    // RMS Norm
+
+    if (rms)
+    {
+        half2 h01 = ((half2*)x_ptr)[t];
+        float f0 = __half2float(__low2half(h01));
+        float f1 = __half2float(__high2half(h01));
+        f0 = fmaxf(-65504.0f, fminf(f0, 65504.0f));
+        f1 = fmaxf(-65504.0f, fminf(f1, 65504.0f));
+        itemf[0] = f0;
+        itemf[1] = f1;
+        sum = fma(f0, f0, sum);
+        sum = fma(f1, f1, sum);
+
+        // Shuffle to sum across lanes
+
+        for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
+        if (lane_id == 0) sums[warp_id] = sum;
+        __syncthreads();
+
+        // Sum of partial sums
+
+        sum = 0.0f;
+        for(int i = 0; i < num_warps; ++i) sum += sums[i];
+
+        // Get 1/sqrt(variance)
+
+        float rsvar = rsqrtf(sum * r_dim + epsilon);
+
+        // Normalize x, scaling by w
+
+        half2 w01 = w_ptr2[t];
+        float n0 = itemf[0] * __half2float(__low2half(w01)) * rsvar;
+        float n1 = itemf[1] * __half2float(__high2half(w01)) * rsvar;
+        half2 nh = __halves2half2(__float2half_rn(n0), __float2half_rn(n1));
+        if (b) nh = __hadd2(nh, b_ptr2[t]); // Optional bias
+        y_ptr2[t] = nh;
+    }
+
+    // LayerNorm
+
+    else
+    {
+        half2 h01 = ((half2*)x_ptr)[t];
+        float f0 = __half2float(__low2half(h01));
+        float f1 = __half2float(__high2half(h01));
+        f0 = fmaxf(-65504.0f, fminf(f0, 65504.0f));
+        f1 = fmaxf(-65504.0f, fminf(f1, 65504.0f));
+        itemf[0] = f0;
+        itemf[1] = f1;
+        sum += f0;
+        sum += f1;
 
-    // Shuffle to sum across lanes
+        // Shuffle to sum across lanes
 
-    for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
-    if (lane_id == 0) sums[warp_id] = sum;
-    __syncthreads();
+        for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
+        if (lane_id == 0) sums[warp_id] = sum;
+        __syncthreads();
 
-    // Sum of partial sums
+        // Sum of partial sums
 
-    sum = 0.0f;
-    for(int i = 0; i < num_warps; ++i) sum += sums[i];
+        sum = 0.0f;
+        for(int i = 0; i < num_warps; ++i) sum += sums[i];
 
-    // Compute mean
+        // Compute mean
 
-    float mean = sum * r_dim;
+        float mean = sum * r_dim;
 
-    // Compute square of distance to mean
+        // Compute square of distance to mean
 
-    sum = 0.0f;
-    itemf[0] -= mean;
-    itemf[1] -= mean;
-    sum = fma(itemf[0], itemf[0], sum);
-    sum = fma(itemf[1], itemf[1], sum);
+        sum = 0.0f;
+        itemf[0] -= mean;
+        itemf[1] -= mean;
+        sum = fma(itemf[0], itemf[0], sum);
+        sum = fma(itemf[1], itemf[1], sum);
 
-    // Shuffle to sum across lanes
+        // Shuffle to sum across lanes
 
-    for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
-    if (lane_id == 0) sums[warp_id] = sum;
-    __syncthreads();
+        for(int offset = warpSize / 2; offset > 0; offset /= 2) sum += __shfl_xor_sync(0xffffffff, sum, offset);
+        if (lane_id == 0) sums[warp_id] = sum;
+        __syncthreads();
 
-    // Sum of partial sums
+        // Sum of partial sums
 
-    sum = 0.0f;
-    for(int i = 0; i < num_warps; ++i) sum += sums[i];
+        sum = 0.0f;
+        for(int i = 0; i < num_warps; ++i) sum += sums[i];
 
-    // Get 1/sqrt(variance)
+        // Get 1/sqrt(variance)
 
-    float rsvar = rsqrtf(sum * r_dim + epsilon);
+        float rsvar = rsqrtf(sum * r_dim + epsilon);
 
-    // Normalize x, scaling by w
+        // Normalize x, scaling by w
 
-    half2 w01 = w_ptr2[t];
-    float n0 = itemf[0] * __half2float(__low2half(w01)) * rsvar;
-    float n1 = itemf[1] * __half2float(__high2half(w01)) * rsvar;
-    half2 nh = __halves2half2(__float2half_rn(n0), __float2half_rn(n1));
-    if (b) nh = __hadd2(nh, b_ptr2[t]); // Optional bias
-    y_ptr2[t] = nh;
+        half2 w01 = w_ptr2[t];
+        float n0 = itemf[0] * __half2float(__low2half(w01)) * rsvar;
+        float n1 = itemf[1] * __half2float(__high2half(w01)) * rsvar;
+        half2 nh = __halves2half2(__float2half_rn(n0), __float2half_rn(n1));
+        if (b) nh = __hadd2(nh, b_ptr2[t]); // Optional bias
+        y_ptr2[t] = nh;
+    }
 }
 
 void head_norm_cuda

@@ -103,6 +148,7 @@ void head_norm_cuda
     const half* b,
     half* y,
    const float epsilon,
+    bool rms,
     const int rows,
     const int num_heads,
     const int head_dim,

@@ -117,7 +163,7 @@ void head_norm_cuda
 
     float r_dim = 1.0f / (float) head_dim;
 
-    head_norm_kernel<<<gridDim, blockDim, 0, stream>>>(x, w, b, y, epsilon, r_dim, rows, num_heads, head_dim);
+    head_norm_kernel<<<gridDim, blockDim, 0, stream>>>(x, w, b, y, epsilon, r_dim, rows, num_heads, head_dim, rms);
     if (graph) graph->attach_label(stream, label, 0);
 }
 
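
For reference, the kernel now normalizes each head's head_dim-wide slice with either RMS norm (new branch) or the original LayerNorm, clamping inputs to the fp16 range and adding epsilon inside the rsqrt. Below is a hedged PyTorch restatement of the two branches, useful for sanity-checking kernel output; the tensor shapes and names are assumptions, not the extension's API:

import torch

def head_norm_ref(x, w, b=None, eps=1e-6, rms=True):
    # x: (rows, num_heads, head_dim) in float16; normalization is over head_dim only.
    xf = x.float().clamp(-65504.0, 65504.0)          # kernel clamps to the fp16 range
    if rms:
        # RMS branch: 1/sqrt(mean(x^2) + eps), then scale by w
        rsvar = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
        y = xf * w.float() * rsvar
    else:
        # LayerNorm branch: subtract the mean, 1/sqrt(var + eps), then scale by w
        mean = xf.mean(dim=-1, keepdim=True)
        var = (xf - mean).pow(2).mean(dim=-1, keepdim=True)
        y = (xf - mean) * w.float() * torch.rsqrt(var + eps)
    if b is not None:
        y = y + b.float()                              # optional bias, added last
    return y.half()

# Example: 4 rows of 8 heads, 256 elements per head, unit weight
q = torch.randn(4, 8, 256, dtype=torch.float16)
w = torch.ones(256, dtype=torch.float16)
out = head_norm_ref(q, w, rms=True)

The MAX_HEAD_DIM bump from 128 to 256 simply widens the largest head size the kernel accepts.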

exllamav2/exllamav2_ext/cuda/head_norm.cuh

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ void head_norm_cuda
     const half* b,
     half* y,
     const float epsilon,
+    bool rms,
     const int rows,
     const int num_heads,
     const int head_dim,

exllamav2/exllamav2_ext/cuda/q_attn.cu

Lines changed: 4 additions & 2 deletions
@@ -86,6 +86,7 @@ QAttn::QAttn
     half* _layernorm,
     half* _layernorm_bias,
     bool _layernorm_is_rms,
+    bool _headnorm_is_rms,
     float _norm_epsilon,
     QMatrix* _q_proj,
     QMatrix* _k_proj,

@@ -115,6 +116,7 @@ QAttn::QAttn
     layernorm(_layernorm),
     layernorm_bias(_layernorm_bias),
     layernorm_is_rms(_layernorm_is_rms),
+    headnorm_is_rms(_headnorm_is_rms),
     norm_epsilon(_norm_epsilon),
     q_proj(_q_proj),
     k_proj(_k_proj),

@@ -281,10 +283,10 @@ void QAttn::forward_cuda_1_run
     apply_loras_cuda(stream, cublas_handle, v_proj_lora, loras, v_proj, norm_state, temp_v, lora_temp, q_len * batch_size);
 
     if (q_norm)
-        head_norm_cuda(stream, temp_q, q_norm, NULL, temp_q, norm_epsilon, q_len * batch_size, num_heads, head_dim, graph, KernelLabels::Q_NORM);
+        head_norm_cuda(stream, temp_q, q_norm, NULL, temp_q, norm_epsilon, headnorm_is_rms, q_len * batch_size, num_heads, head_dim, graph, KernelLabels::Q_NORM);
 
     if (k_norm)
-        head_norm_cuda(stream, temp_k, k_norm, NULL, temp_k, norm_epsilon, q_len * batch_size, num_kv_heads, head_dim, graph, KernelLabels::K_NORM);
+        head_norm_cuda(stream, temp_k, k_norm, NULL, temp_k, norm_epsilon, headnorm_is_rms, q_len * batch_size, num_kv_heads, head_dim, graph, KernelLabels::K_NORM);
 
     // rope_cuda(stream, temp_q, sin, cos, batch_size, q_len * num_heads, head_dim, num_heads, past_len, past_lens);
     // rope_cuda(stream, temp_k, sin, cos, batch_size, q_len * num_kv_heads, head_dim, num_kv_heads, past_len, past_lens);
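
In forward_cuda_1_run the per-head Q/K norms, when present, are applied to the projected Q and K before rotary embeddings, and both calls now pass headnorm_is_rms through to head_norm_cuda. A hedged sketch of that ordering in plain PyTorch; the function and tensor names here are placeholders, not the QAttn API:

import torch
import torch.nn.functional as F

def rms_head_norm(x, w, eps=1e-6):
    # Normalize each (head_dim,) slice independently, as head_norm_cuda does
    rsvar = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * w.float() * rsvar).to(x.dtype)

def norm_then_rope(q, k, q_norm_w, k_norm_w, apply_rope, rms=True, eps=1e-6):
    # q: (rows, num_heads, head_dim), k: (rows, num_kv_heads, head_dim)
    if q_norm_w is not None:
        q = rms_head_norm(q, q_norm_w, eps) if rms else F.layer_norm(q, q.shape[-1:], q_norm_w, eps=eps)
    if k_norm_w is not None:
        k = rms_head_norm(k, k_norm_w, eps) if rms else F.layer_norm(k, k.shape[-1:], k_norm_w, eps=eps)
    # Rotary embeddings are applied only after the per-head norm
    return apply_rope(q), apply_rope(k)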

exllamav2/exllamav2_ext/cuda/q_attn.cuh

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,7 @@ public:
     half* post_layernorm;
     half* post_layernorm_bias;
     bool layernorm_is_rms;
+    bool headnorm_is_rms;
     float norm_epsilon;
 
     half* q_norm;

@@ -86,6 +87,7 @@ public:
     half* _layernorm,
     half* _layernorm_bias,
     bool _layernorm_is_rms,
+    bool _headnorm_is_rms,
     float _norm_epsilon,
     QMatrix* _q_proj,
     QMatrix* _k_proj,

exllamav2/exllamav2_ext/ext_norm.cpp

Lines changed: 6 additions & 3 deletions
@@ -162,7 +162,8 @@ void head_norm
     torch::Tensor w,
     torch::Tensor b,
     torch::Tensor y,
-    float epsilon
+    float epsilon,
+    bool rms
 )
 {
     TORCH_CHECK_DTYPE(x, kHalf);

@@ -191,6 +192,7 @@ void head_norm
         b.device().is_meta() ? NULL : (half*) b.data_ptr(),
         (half*) y.data_ptr(),
         epsilon,
+        rms,
         rows,
         num_heads,
         head_dim

@@ -202,8 +204,9 @@ void head_norm_
     torch::Tensor x,
     torch::Tensor w,
     torch::Tensor b,
-    float epsilon
+    float epsilon,
+    bool rms
 )
 {
-    head_norm(x, w, b, x, epsilon);
+    head_norm(x, w, b, x, epsilon, rms);
 }
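
With this change both extension entry points take a trailing rms flag: head_norm(x, w, b, y, epsilon, rms) and the in-place head_norm_(x, w, b, epsilon, rms). A hedged usage sketch from Python; the import path and the expected layout of x are assumptions, since neither the pybind registration nor the shape checks appear in this diff:

# Assumption: exllamav2 exposes the compiled module as exllamav2.ext.exllamav2_ext,
# and the pybind names match the C++ functions above.

import torch
from exllamav2.ext import exllamav2_ext as ext_c

rows, num_heads, head_dim = 4, 8, 256
x = torch.randn(rows, num_heads, head_dim, dtype=torch.half, device="cuda")
w = torch.ones(head_dim, dtype=torch.half, device="cuda")
b = torch.empty(head_dim, dtype=torch.half, device="meta")  # meta tensor -> NULL bias in the wrapper
y = torch.empty_like(x)

ext_c.head_norm(x, w, b, y, 1e-6, True)    # out-of-place, RMS head norm
ext_c.head_norm_(x, w, b, 1e-6, False)     # in-place, LayerNorm head norm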

exllamav2/exllamav2_ext/ext_norm.h

Lines changed: 4 additions & 2 deletions
@@ -46,15 +46,17 @@ void head_norm
     torch::Tensor w,
     torch::Tensor b,
     torch::Tensor y,
-    float epsilon
+    float epsilon,
+    bool rms
 );
 
 void head_norm_
 (
     torch::Tensor x,
     torch::Tensor w,
     torch::Tensor b,
-    float epsilon
+    float epsilon,
+    bool rms
 );
 

exllamav2/exllamav2_ext/ext_qattn.cpp

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@ uintptr_t make_q_attn
     torch::Tensor layernorm,
     torch::Tensor layernorm_bias,
     bool layernorm_is_rms,
+    bool headnorm_is_rms,
     float norm_epsilon,
     uintptr_t q_q_proj,
     uintptr_t q_k_proj,

@@ -71,6 +72,7 @@ uintptr_t make_q_attn
         layernorm.is_meta() ? NULL : (half*) layernorm.data_ptr(),
         layernorm_bias.is_meta() ? NULL : (half*) layernorm_bias.data_ptr(),
         layernorm_is_rms,
+        headnorm_is_rms,
         norm_epsilon,
         qm_q_proj,
         qm_k_proj,

exllamav2/exllamav2_ext/ext_qattn.h

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ uintptr_t make_q_attn
     torch::Tensor layernorm,
     torch::Tensor layernorm_bias,
     bool layernorm_is_rms,
+    bool headnorm_is_rms,
     float norm_epsilon,
     uintptr_t q_q_proj,
     uintptr_t q_k_proj,
