Skip to content

Commit 7b8f319

Browse files
authored
fix(kernels): correct typo in LayerNorm kernel at line 73 110 346 443 (#317)
1 parent 7aa190c commit 7b8f319

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

kernels/layer-norm/layer_norm.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ __global__ void layer_norm_f32_kernel(float *x, float *y, float g, float b,
7070
float variance = (value - s_mean) * (value - s_mean);
7171
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
7272
if (tid == 0)
73-
s_variance = rsqrtf(variance / ((float)K + epsilon));
73+
s_variance = rsqrtf(variance / (float)K + epsilon);
7474
// wait for s_variance in shared memory to be ready for all threads
7575
__syncthreads();
7676
if (idx < N * K)
@@ -107,7 +107,7 @@ __global__ void layer_norm_f32x4_kernel(float *x, float *y, float g, float b,
107107
reg_x_hat.z * reg_x_hat.z + reg_x_hat.w * reg_x_hat.w;
108108
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
109109
if (tid == 0)
110-
s_variance = rsqrtf(variance / ((float)K + epsilon));
110+
s_variance = rsqrtf(variance / (float)K + epsilon);
111111
// wait for s_variance in shared memory to be ready for all threads
112112
__syncthreads();
113113
float4 reg_y;
@@ -343,7 +343,7 @@ __global__ void layer_norm_f16_f32_kernel(half *x, half *y, float g, float b,
343343
float variance = (value - s_mean) * (value - s_mean);
344344
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
345345
if (tid == 0)
346-
s_variance = rsqrtf(variance / ((float)K + epsilon));
346+
s_variance = rsqrtf(variance / (float)K + epsilon);
347347
// wait for s_variance in shared memory to be ready for all threads
348348
__syncthreads();
349349
if (idx < N * K) {
@@ -440,7 +440,7 @@ __global__ void layer_norm_f16x8_pack_f32_kernel(half *x, half *y, float g,
440440
}
441441
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
442442
if (tid == 0)
443-
s_variance = rsqrtf(variance / ((float)K + epsilon));
443+
s_variance = rsqrtf(variance / (float)K + epsilon);
444444
// wait for s_variance in shared memory to be ready for all threads
445445
__syncthreads();
446446

0 commit comments

Comments
 (0)