@@ -172,7 +172,7 @@ __global__ void rms_norm_f16_f16_kernel(half *x, half *y, float g, int N,
   half variance = value * value;
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   if (idx < N * K)
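All of the epsilon hunks in this commit make the same change: epsilon is added to the mean of squares rather than to K, so the kernels now compute `rsqrt(sum(x^2)/K + eps)`, the usual RMS norm scale. A minimal host-side sketch of the corrected formula (plain C++; the function name, epsilon default, and the final multiply by `g` are illustrative assumptions, not code from this PR):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical host-side reference for the corrected formula:
// y[i] = x[i] * g / sqrt(sum_j(x[j]^2) / K + epsilon)
std::vector<float> rms_norm_ref(const std::vector<float> &x, float g,
                                float epsilon = 1e-5f) {
  const float K = static_cast<float>(x.size());
  float sum_sq = 0.0f;
  for (float v : x)
    sum_sq += v * v; // what block_reduce_sum_* produces per row
  // epsilon is added after dividing by K, matching the fixed kernels
  const float s_variance = 1.0f / std::sqrt(sum_sq / K + epsilon);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    y[i] = x[i] * s_variance * g;
  return y;
}

int main() {
  const std::vector<float> y = rms_norm_ref({1.0f, 2.0f, 3.0f, 4.0f}, 1.0f);
  for (float v : y)
    std::printf("%f ", v);
  std::printf("\n");
  return 0;
}
```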
@@ -195,7 +195,7 @@ __global__ void rms_norm_f16x2_f16_kernel(half *x, half *y, float g, int N,
                    : __float2half(0.0f);
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   half2 reg_y;
@@ -241,7 +241,7 @@ __global__ void rms_norm_f16x8_f16_kernel(half *x, half *y, float g, int N,
   variance += HALF2_VARIANCE(reg_x_3, 6);
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   // manual unroll
@@ -292,7 +292,7 @@ __global__ void rms_norm_f16x8_f32_kernel(half *x, half *y, float g, int N,

   variance = block_reduce_sum_f32<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = rsqrtf(variance / ((float)K + epsilon));
+    s_variance = rsqrtf(variance / (float)K + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   // manual unroll
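The `*_f32` variants accumulate the sum of squares in float and apply the same fix with `rsqrtf`. The epsilon placement is not just cosmetic: in the old expression epsilon was swamped by K and no longer guarded the reciprocal square root, so an all-zero row produced `inf`. A standalone sketch (values chosen arbitrarily, not taken from the PR) showing the difference:

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical illustration of why the epsilon placement matters:
// for an all-zero row the old expression takes 1/sqrt of exactly 0,
// while the fixed expression is bounded by 1/sqrt(epsilon).
int main() {
  const float variance = 0.0f; // sum of squares of an all-zero row
  const float K = 256.0f;      // arbitrary row width
  const float epsilon = 1e-5f; // arbitrary epsilon
  const float old_s = 1.0f / std::sqrt(variance / (K + epsilon)); // inf
  const float new_s = 1.0f / std::sqrt(variance / K + epsilon);   // ~316.2
  std::printf("old: %f  new: %f\n", old_s, new_s);
  return 0;
}
```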
@@ -328,7 +328,7 @@ __global__ void rms_norm_f16_f32_kernel(half *x, half *y, float g, int N,
   float variance = value * value;
   variance = block_reduce_sum_f32<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = rsqrtf(variance / ((float)K + epsilon));
+    s_variance = rsqrtf(variance / (float)K + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();
   if (idx < N * K) {
@@ -360,7 +360,7 @@ __global__ void rms_norm_f16x8_pack_f16_kernel(half *x, half *y, float g, int N,
   }
   variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = hrsqrt(variance / (K_ + epsilon));
+    s_variance = hrsqrt(variance / K_ + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();

@@ -396,7 +396,7 @@ __global__ void rms_norm_f16x8_pack_f32_kernel(half *x, half *y, float g, int N,
   }
   variance = block_reduce_sum_f32<NUM_THREADS>(variance);
   if (tid == 0)
-    s_variance = rsqrtf(variance / ((float)K + epsilon));
+    s_variance = rsqrtf(variance / (float)K + epsilon);
   // wait for s_variance in shared memory to be ready for all threads
   __syncthreads();

@@ -626,7 +626,7 @@ void rms_norm_f32x4(torch::Tensor x, torch::Tensor y, float g) {
 }

 #define LANUCH_RMS_NORM_F16x8F32_KERNEL(K)                                    \
-  rms_norm_f16x8_f16_kernel<(K) / 8>                                          \
+  rms_norm_f16x8_f32_kernel<(K) / 8>                                          \
       <<<grid, block>>>(reinterpret_cast<half *>(x.data_ptr()),               \
                         reinterpret_cast<half *>(y.data_ptr()), g, N, (K));

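The last hunk fixes a different bug: `LANUCH_RMS_NORM_F16x8F32_KERNEL` was instantiating `rms_norm_f16x8_f16_kernel`, so the f16x8 variant with f32 accumulation was never actually launched through this macro. After the fix, a call site such as `LANUCH_RMS_NORM_F16x8F32_KERNEL(1024)` (the value 1024 is only an example) expands to a launch of the matching kernel, with K/8 as the thread count, presumably because each thread handles eight half elements (the x8 in the name):

```cpp
// Hypothetical preprocessor expansion of LANUCH_RMS_NORM_F16x8F32_KERNEL(1024)
// after the fix; grid, block, x, y, g and N come from the enclosing launcher.
rms_norm_f16x8_f32_kernel<(1024) / 8>
    <<<grid, block>>>(reinterpret_cast<half *>(x.data_ptr()),
                      reinterpret_cast<half *>(y.data_ptr()), g, N, (1024));
```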