@@ -70,7 +70,7 @@ __global__ void layer_norm_f32_kernel(float *x, float *y, float g, float b,
70
70
float variance = (value - s_mean) * (value - s_mean);
71
71
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
72
72
if (tid == 0 )
73
- s_variance = rsqrtf (variance / (( float )K + epsilon) );
73
+ s_variance = rsqrtf (variance / (float )K + epsilon);
74
74
// wait for s_variance in shared memory to be ready for all threads
75
75
__syncthreads ();
76
76
if (idx < N * K)
@@ -107,7 +107,7 @@ __global__ void layer_norm_f32x4_kernel(float *x, float *y, float g, float b,
107
107
reg_x_hat.z * reg_x_hat.z + reg_x_hat.w * reg_x_hat.w ;
108
108
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
109
109
if (tid == 0 )
110
- s_variance = rsqrtf (variance / (( float )K + epsilon) );
110
+ s_variance = rsqrtf (variance / (float )K + epsilon);
111
111
// wait for s_variance in shared memory to be ready for all threads
112
112
__syncthreads ();
113
113
float4 reg_y;
@@ -343,7 +343,7 @@ __global__ void layer_norm_f16_f32_kernel(half *x, half *y, float g, float b,
343
343
float variance = (value - s_mean) * (value - s_mean);
344
344
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
345
345
if (tid == 0 )
346
- s_variance = rsqrtf (variance / (( float )K + epsilon) );
346
+ s_variance = rsqrtf (variance / (float )K + epsilon);
347
347
// wait for s_variance in shared memory to be ready for all threads
348
348
__syncthreads ();
349
349
if (idx < N * K) {
@@ -440,7 +440,7 @@ __global__ void layer_norm_f16x8_pack_f32_kernel(half *x, half *y, float g,
440
440
}
441
441
variance = block_reduce_sum_f32<NUM_THREADS>(variance);
442
442
if (tid == 0 )
443
- s_variance = rsqrtf (variance / (( float )K + epsilon) );
443
+ s_variance = rsqrtf (variance / (float )K + epsilon);
444
444
// wait for s_variance in shared memory to be ready for all threads
445
445
__syncthreads ();
446
446
0 commit comments