Commit 922403e

Optimize rms_norm_kernel by removing redundant global memory loads

Signed-off-by: Benji Beck <[email protected]>

1 parent 554df8a commit 922403e


csrc/layernorm_kernels.cu

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,6 @@
 
 namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,           // [..., hidden_size]
@@ -21,10 +20,14 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
+  __shared__ float s_input[1024];
+
   float variance = 0.0f;
+  int row_offset = blockIdx.x * input_stride;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = (float)input[row_offset + idx];
+    s_input[idx] = x;
     variance += x * x;
   }
 
@@ -38,7 +41,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = s_input[idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
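For context, below is a minimal sketch of how the whole kernel reads after this commit. The diff elides the middle of the function, so the block-wide reduction that produces s_variance (the s_partial buffer and the rsqrtf epilogue) is an assumption standing in for whatever reduction helper the real file uses, and the declarations of the elided input and input_stride parameters are inferred from their uses in the diff; everything else mirrors the hunks above.

// Hedged sketch of rms_norm_kernel after this commit (compiles as a .cu file).
// The s_partial reduction and the input/input_stride declarations are
// assumptions, not the repository's code.
namespace sketch {

template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size] (declaration assumed)
    const int input_stride,               // row stride of input (declaration assumed)
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  __shared__ float s_input[1024];    // row cache added by this commit; caps hidden_size at 1024
  __shared__ float s_partial[1024];  // assumed scratch for the stand-in reduction

  float variance = 0.0f;
  int row_offset = blockIdx.x * input_stride;

  // Pass 1: accumulate the sum of squares and stage the row in shared memory.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    s_input[idx] = x;
    variance += x * x;
  }

  // Stand-in block-wide reduction (assumes blockDim.x is a power of two).
  s_partial[threadIdx.x] = variance;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_partial[threadIdx.x] += s_partial[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(s_partial[0] / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: re-read the row from shared memory instead of global memory.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = s_input[idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

}  // namespace sketch

Since the thread-strided loops index s_input up to hidden_size - 1, the fixed 1024-element cache assumes hidden_size <= 1024. A plausible launch matching the blockIdx.x row indexing (an assumption, not shown in the diff) is one block per token, e.g. sketch::rms_norm_kernel<float><<<num_tokens, 1024>>>(out, input, input_stride, weight, 1e-6f, num_tokens, hidden_size).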
