@@ -70,7 +70,7 @@ __global__ void layer_norm_f32_kernel(float *x, float *y, float g, float b,
7070 float variance = (value - s_mean) * (value - s_mean);
7171 variance = block_reduce_sum_f32<NUM_THREADS>(variance);
7272 if (tid == 0 )
73- s_variance = rsqrtf (variance / (( float )K + epsilon) );
73+ s_variance = rsqrtf (variance / (float )K + epsilon);
7474 // wait for s_variance in shared memory to be ready for all threads
7575 __syncthreads ();
7676 if (idx < N * K)
@@ -107,7 +107,7 @@ __global__ void layer_norm_f32x4_kernel(float *x, float *y, float g, float b,
107107 reg_x_hat.z * reg_x_hat.z + reg_x_hat.w * reg_x_hat.w ;
108108 variance = block_reduce_sum_f32<NUM_THREADS>(variance);
109109 if (tid == 0 )
110- s_variance = rsqrtf (variance / (( float )K + epsilon) );
110+ s_variance = rsqrtf (variance / (float )K + epsilon);
111111 // wait for s_variance in shared memory to be ready for all threads
112112 __syncthreads ();
113113 float4 reg_y;
@@ -343,7 +343,7 @@ __global__ void layer_norm_f16_f32_kernel(half *x, half *y, float g, float b,
343343 float variance = (value - s_mean) * (value - s_mean);
344344 variance = block_reduce_sum_f32<NUM_THREADS>(variance);
345345 if (tid == 0 )
346- s_variance = rsqrtf (variance / (( float )K + epsilon) );
346+ s_variance = rsqrtf (variance / (float )K + epsilon);
347347 // wait for s_variance in shared memory to be ready for all threads
348348 __syncthreads ();
349349 if (idx < N * K) {
@@ -440,7 +440,7 @@ __global__ void layer_norm_f16x8_pack_f32_kernel(half *x, half *y, float g,
440440 }
441441 variance = block_reduce_sum_f32<NUM_THREADS>(variance);
442442 if (tid == 0 )
443- s_variance = rsqrtf (variance / (( float )K + epsilon) );
443+ s_variance = rsqrtf (variance / (float )K + epsilon);
444444 // wait for s_variance in shared memory to be ready for all threads
445445 __syncthreads ();
446446
0 commit comments