Update rms_norm_kernel by removing redundant global memory loads #22134

Open · wants to merge 2 commits into main

16 changes: 12 additions & 4 deletions csrc/layernorm_kernels.cu
@@ -12,7 +12,6 @@

 namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,  // [..., hidden_size]
@@ -21,10 +20,15 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
+  // Dynamic shared memory allocation for input values
+  extern __shared__ float s_input[];
+
   float variance = 0.0f;
+  int row_offset = blockIdx.x * input_stride;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = (float)input[row_offset + idx];
+    s_input[idx] = x;
     variance += x * x;
   }

@@ -38,7 +42,7 @@
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = s_input[idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
@@ -158,8 +162,12 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Calculate shared memory size for dynamic allocation
+  size_t shared_mem_bytes = hidden_size * sizeof(float);
+
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shared_mem_bytes, stream>>>(
         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
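
For reference, a minimal standalone sketch of the pattern this PR applies: cache each input row in dynamic shared memory during the variance pass, then reuse it in the normalization pass instead of re-reading global memory. The names (rms_norm_cached, s_partial) are hypothetical, the sketch is float-only, and it substitutes a naive tree reduction for the block-reduction helper the real kernel uses; it assumes blockDim.x is a power of two and that hidden_size * sizeof(float) fits in the shared memory available per block.

// Standalone sketch (not the vLLM kernel): one block per token row.
#include <cuda_runtime.h>

__global__ void rms_norm_cached(float* __restrict__ out,
                                const float* __restrict__ input,
                                const float* __restrict__ weight,
                                float epsilon, int hidden_size) {
  extern __shared__ float s_input[];  // hidden_size floats: one cached row
  __shared__ float s_inv_rms;
  __shared__ float s_partial[1024];   // assumes blockDim.x <= 1024, power of two

  const int row = blockIdx.x;
  float sum_sq = 0.0f;

  // Pass 1: single global-memory read per element; cache it and accumulate x*x.
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = input[row * hidden_size + i];
    s_input[i] = x;
    sum_sq += x * x;
  }

  // Naive block-wide tree reduction of sum_sq.
  s_partial[threadIdx.x] = sum_sq;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride)
      s_partial[threadIdx.x] += s_partial[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0)
    s_inv_rms = rsqrtf(s_partial[0] / hidden_size + epsilon);
  __syncthreads();

  // Pass 2: read the cached values from shared memory, not global memory.
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    out[row * hidden_size + i] = s_input[i] * s_inv_rms * weight[i];
  }
}

// Launch: one block per token, hidden_size * sizeof(float) bytes of dynamic
// shared memory, e.g.
//   rms_norm_cached<<<num_tokens, 256, hidden_size * sizeof(float)>>>(
//       out, input, weight, 1e-6f, hidden_size);

One thing worth checking with this approach: shared_mem_bytes grows with hidden_size, and dynamic shared memory above 48 KB per block (roughly 12K floats) requires an explicit opt-in via cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize on most current GPUs, otherwise the launch fails for large hidden sizes.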