
Commit 303987a

Optimize rms_norm_kernel by removing redundant global memory loads
Signed-off-by: Benji Beck <[email protected]>
1 parent 554df8a commit 303987a

1 file changed: 7 additions, 3 deletions

csrc/layernorm_kernels.cu

Lines changed: 7 additions & 3 deletions
@@ -12,7 +12,6 @@
 
 namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
@@ -21,10 +20,15 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
+  // Assumes hidden_size <= 1024, enforced via launch config
+  __shared__ float s_input[1024];
+
   float variance = 0.0f;
+  int row_offset = blockIdx.x * input_stride;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = (float)input[row_offset + idx];
+    s_input[idx] = x;
     variance += x * x;
   }
 
@@ -38,7 +42,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = s_input[idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
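
For reference, the hunks above elide the middle of the kernel, where each thread's partial sum of x*x is reduced across the block and s_variance is set to the reciprocal RMS of the row. The sketch below is a reconstruction under the commit's stated assumptions (hidden_size <= 1024, one block per row), not the upstream file: rms_norm_kernel_sketch is a hypothetical name, the input_stride type is assumed, and a plain shared-memory tree reduction stands in for the block-reduction helper the real kernel uses.

#include <cuda_runtime.h>

template <typename scalar_t>
__global__ void rms_norm_kernel_sketch(
    scalar_t* __restrict__ out,          // [..., hidden_size]
    const scalar_t* __restrict__ input,  // [..., hidden_size]
    const int64_t input_stride,          // assumed type; row stride of input
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  // Row cache added by this commit: the normalization pass below reads shared
  // memory instead of reloading the row from global memory.
  __shared__ float s_input[1024];

  float variance = 0.0f;
  int row_offset = blockIdx.x * input_stride;

  // Pass 1: load the row once, cache it, and accumulate per-thread sums of x^2.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[row_offset + idx];
    s_input[idx] = x;
    variance += x * x;
  }

  // Block-wide sum of the partial sums (stand-in for the upstream reduction
  // helper); assumes blockDim.x is a power of two and <= 1024.
  __shared__ float s_partial[1024];
  s_partial[threadIdx.x] = variance;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_partial[threadIdx.x] += s_partial[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    // Reciprocal RMS of the row, matching the x * s_variance use below.
    s_variance = rsqrtf(s_partial[0] / hidden_size + epsilon);
  }
  __syncthreads();

  // Pass 2: normalize from the cached row instead of global memory.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = s_input[idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

With this caching in place, a matching launch uses one block per token and at most 1024 threads per block (e.g. a block size of min(hidden_size, 1024)), so that s_input can hold the entire row; that launch-side cap is what the "enforced via launch config" comment in the diff refers to.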
