
Commit ace5f16

bugfix: fix layernorm & rmsnorm f16 overflow (#335)
* [Test] Add f16 overflow testcase in layernorm and rmsnorm
* [Fix] call correct `rms_norm_f16x8_f32` and fix epsilon position
1 parent 92feb71
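
The commit addresses two distinct bugs. First, every f16 variant of the LayerNorm and RMSNorm kernels computed `hrsqrt(variance / (K_ + epsilon))`, adding epsilon to the element count `K_` rather than to the mean square: with `K_` near 1024 and `epsilon` near 1e-5, the epsilon term is numerically invisible, and the reciprocal square root is left unguarded against a zero variance. Second, the launch macro for `rms_norm_f16x8_f32` dispatched the f16-accumulating kernel, so the f32 path was never actually executed. A minimal standalone sketch of the epsilon change (the `inv_rms_*` helpers are hypothetical, not code from this commit):

```cuda
// Hypothetical helpers contrasting the two epsilon placements.
__device__ __forceinline__ float inv_rms_old(float sum_sq, float K,
                                             float epsilon) {
  // bug: epsilon ~ 1e-5 vanishes next to K ~ 1024, and rsqrtf(0.0f) == inf
  return rsqrtf(sum_sq / (K + epsilon));
}

__device__ __forceinline__ float inv_rms_new(float sum_sq, float K,
                                             float epsilon) {
  // fix: epsilon floors the mean square, matching the textbook definition
  return rsqrtf(sum_sq / K + epsilon);
}
```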

File tree: 4 files changed (+37, -12 lines)

kernels/layer-norm/layer_norm.cu (4 additions, 4 deletions)

```diff
@@ -197,7 +197,7 @@ __global__ void layer_norm_f16_f16_kernel(half *x, half *y, float g, float b,
   half variance = (value - s_mean) * (value - s_mean);
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   if (idx < N * K) {
@@ -232,7 +232,7 @@ __global__ void layer_norm_f16x2_f16_kernel(half *x, half *y, float g, float b,
   half variance = reg_x_hat.x * reg_x_hat.x + reg_x_hat.y * reg_x_hat.y;
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   if (idx < N * K) {
@@ -300,7 +300,7 @@ __global__ void layer_norm_f16x8_f16_kernel(half *x, half *y, float g, float b,

   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   // manual unroll
@@ -390,7 +390,7 @@ __global__ void layer_norm_f16x8_pack_f16_kernel(half *x, half *y, float g,
   }
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();

```
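
For reference, the formula the fixed kernels implement, as a hedged single-thread f32 sketch (`layer_norm_row_ref` is hypothetical, not a kernel from this repo): each row becomes `(x - mean) * rsqrt(var / K + epsilon) * g + b`, with epsilon applied after the division.

```cuda
#include <cuda_fp16.h>

// Hypothetical single-thread reference for one row; the real kernels
// block-reduce mean and variance across NUM_THREADS instead of looping.
__global__ void layer_norm_row_ref(const half *x, half *y, float g, float b,
                                   int K, float epsilon) {
  float mean = 0.0f, var = 0.0f;
  for (int i = 0; i < K; ++i)
    mean += __half2float(x[i]);
  mean /= K;
  for (int i = 0; i < K; ++i) {
    float d = __half2float(x[i]) - mean;
    var += d * d; // sum of squared deviations, accumulated in f32
  }
  float s = rsqrtf(var / K + epsilon); // epsilon outside the division
  for (int i = 0; i < K; ++i)
    y[i] = __float2half((__half2float(x[i]) - mean) * s * g + b);
}
```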

kernels/layer-norm/layer_norm.py (12 additions, 0 deletions)

```diff
@@ -96,6 +96,18 @@ def run_benchmark(
 run_benchmark(naive_layer_norm, x_f16, "f16_th")
 print("-" * 85)

+print(" " * 40 + f"f16 overflow without f32")
+print("-" * 85)
+x_f16 = x.half() * 100  # this will cause overflow for kernels without `f32`
+run_benchmark(lib.layer_norm_f16_f16, x_f16, "f16f16", out_f16)
+run_benchmark(lib.layer_norm_f16_f32, x_f16, "f16f32", out_f16)
+run_benchmark(lib.layer_norm_f16x2_f16, x_f16, "f16x2f16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.layer_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
+run_benchmark(naive_layer_norm, x_f16, "f16_th")
+print("-" * 85)
+
 print("-" * 85)
 N, K = 4096, 1024
 print(" " * 40 + f"N={N}, K={K}")
```

kernels/rms-norm/rms_norm.cu (8 additions, 8 deletions)

```diff
@@ -172,7 +172,7 @@ __global__ void rms_norm_f16_f16_kernel(half *x, half *y, float g, int N,
   half variance = value * value;
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   if (idx < N * K)
@@ -195,7 +195,7 @@ __global__ void rms_norm_f16x2_f16_kernel(half *x, half *y, float g, int N,
                       : __float2half(0.0f);
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   half2 reg_y;
@@ -241,7 +241,7 @@ __global__ void rms_norm_f16x8_f16_kernel(half *x, half *y, float g, int N,
   variance += HALF2_VARIANCE(reg_x_3, 6);
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   // manual unroll
@@ -292,7 +292,7 @@ __global__ void rms_norm_f16x8_f32_kernel(half *x, half *y, float g, int N,

   variance = block_reduce_sum_f32<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = rsqrtf(variance / ((float)K + epsilon));
+    s_variance = rsqrtf(variance / (float)K + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   // manual unroll
@@ -328,7 +328,7 @@ __global__ void rms_norm_f16_f32_kernel(half *x, half *y, float g, int N,
   float variance = value * value;
   variance = block_reduce_sum_f32<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = rsqrtf(variance / ((float)K + epsilon));
+    s_variance = rsqrtf(variance / (float)K + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   if (idx < N * K) {
@@ -360,7 +360,7 @@ __global__ void rms_norm_f16x8_pack_f16_kernel(half *x, half *y, float g, int N,
   }
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();

@@ -396,7 +396,7 @@ __global__ void rms_norm_f16x8_pack_f32_kernel(half *x, half *y, float g, int N,
   }
   variance = block_reduce_sum_f32<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = rsqrtf(variance / ((float)K + epsilon));
+    s_variance = rsqrtf(variance / (float)K + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();

@@ -626,7 +626,7 @@ void rms_norm_f32x4(torch::Tensor x, torch::Tensor y, float g) {
 }

 #define LANUCH_RMS_NORM_F16x8F32_KERNEL(K)                                     \
-  rms_norm_f16x8_f16_kernel<(K) / 8>                                           \
+  rms_norm_f16x8_f32_kernel<(K) / 8>                                           \
       <<<grid, block>>>(reinterpret_cast<half *>(x.data_ptr()),                \
                         reinterpret_cast<half *>(y.data_ptr()), g, N, (K));

```
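
The last hunk is the dispatch half of the fix: the `LANUCH_RMS_NORM_F16x8F32_KERNEL` macro (the misspelling is in the source) was instantiating `rms_norm_f16x8_f16_kernel`, so the f16x8_f32 entry point silently ran the overflow-prone f16 path. What the correctly dispatched kernel computes per row, as a hedged single-thread sketch (`rms_norm_row_ref` is hypothetical, not from this repo):

```cuda
#include <cuda_fp16.h>

// Hypothetical reference: y = x * rsqrt(mean(x^2) + epsilon) * g for one row.
// The real kernel uses x8 vectorized loads and a block-wide f32 reduction.
__global__ void rms_norm_row_ref(const half *x, half *y, float g, int K,
                                 float epsilon) {
  float sum_sq = 0.0f;
  for (int i = 0; i < K; ++i) {
    float v = __half2float(x[i]);
    sum_sq += v * v; // f32 accumulation keeps |x| ~ 100 inputs finite
  }
  float s = rsqrtf(sum_sq / K + epsilon); // epsilon outside the division
  for (int i = 0; i < K; ++i)
    y[i] = __float2half(__half2float(x[i]) * s * g);
}
```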

kernels/rms-norm/rms_norm.py (13 additions, 0 deletions)

```diff
@@ -95,6 +95,19 @@ def run_benchmark(
 run_benchmark(naive_rms_norm, x_f16, "f16_th")
 print("-" * 85)

+print(" " * 40 + f"f16 overflow without f32")
+print("-" * 85)
+x_f16 = x.half() * 100  # this will cause overflow for kernels without `f32`
+run_benchmark(lib.rms_norm_f16_f16, x_f16, "f16f16", out_f16)
+run_benchmark(lib.rms_norm_f16_f32, x_f16, "f16f32", out_f16)
+run_benchmark(lib.rms_norm_f16x2_f16, x_f16, "f16x2f16", out_f16)
+run_benchmark(lib.rms_norm_f16x8_f16, x_f16, "f16x8f16", out_f16)
+run_benchmark(lib.rms_norm_f16x8_f32, x_f16, "f16x8f32", out_f16)
+run_benchmark(lib.rms_norm_f16x8_pack_f16, x_f16, "f16x8packf16", out_f16)
+run_benchmark(lib.rms_norm_f16x8_pack_f32, x_f16, "f16x8packf32", out_f16)
+run_benchmark(naive_rms_norm, x_f16, "f16_th")
+print("-" * 85)
+
 print("-" * 85)
 N, K = 4096, 1024
 print(" " * 40 + f"N={N}, K={K}")
```
