diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index f051eb070222..12692a238006 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -12,7 +12,6 @@
 
 namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,           // [..., hidden_size]
@@ -21,10 +20,15 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
+  // Dynamic shared memory allocation for input values
+  extern __shared__ float s_input[];
+
   float variance = 0.0f;
 
+  int row_offset = blockIdx.x * input_stride;
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = (float)input[row_offset + idx];
+    s_input[idx] = x;
     variance += x * x;
   }
 
@@ -38,7 +42,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
+    float x = s_input[idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
@@ -158,8 +162,12 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Calculate shared memory size for dynamic allocation
+  size_t shared_mem_bytes = hidden_size * sizeof(float);
+
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shared_mem_bytes, stream>>>(
         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
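
For context, the technique this diff relies on is CUDA's dynamic shared memory: declaring `extern __shared__ float s_input[];` in the kernel and sizing the buffer at launch via the third `<<<...>>>` parameter, so each row of `input` is read from global memory once and reread from shared memory for the second pass. Below is a minimal standalone sketch of that pattern, not part of the diff; the kernel name `cache_row_kernel` and the sizes are hypothetical. Note that dynamic shared memory above the 48 KB per-block default requires an explicit opt-in via `cudaFuncSetAttribute`, which the diff itself does not do, so this approach assumes `hidden_size * sizeof(float)` stays within that limit (hidden_size <= 12288).

```cuda
// Standalone sketch (hypothetical names) of the dynamic shared-memory
// pattern used in the diff: the third launch parameter sets the byte
// count backing the kernel's `extern __shared__` array.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void cache_row_kernel(float* out, const float* in, int n) {
  extern __shared__ float s_row[];  // sized by the launch's smem argument
  // First pass: stage this block's row in shared memory.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    s_row[i] = in[blockIdx.x * n + i];
  }
  __syncthreads();
  // Second pass: reread from shared memory instead of global memory.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    out[blockIdx.x * n + i] = s_row[i] * 2.0f;
  }
}

int main() {
  const int rows = 4, n = 1024;
  float *in, *out;
  cudaMalloc(&in, rows * n * sizeof(float));
  cudaMalloc(&out, rows * n * sizeof(float));
  cudaMemset(in, 0, rows * n * sizeof(float));

  // Mirrors the diff's `hidden_size * sizeof(float)`.
  size_t smem_bytes = n * sizeof(float);
  // For smem_bytes > 48 KB, the kernel must opt in first:
  // cudaFuncSetAttribute(cache_row_kernel,
  //     cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_bytes);
  cache_row_kernel<<<rows, 256, smem_bytes>>>(out, in, n);
  cudaDeviceSynchronize();
  printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(in);
  cudaFree(out);
  return 0;
}
```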