
Commit e04ac2f

Fix rms_norm_kernel to use dynamic shared memory for large hidden_size
Signed-off-by: Benji Beck <[email protected]>
1 parent 303987a commit e04ac2f

File tree

1 file changed: +7 −3 lines changed


csrc/layernorm_kernels.cu

Lines changed: 7 additions & 3 deletions
@@ -20,8 +20,8 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
-  // Assumes hidden_size <= 1024, enforced via launch config
-  __shared__ float s_input[1024];
+  // Dynamic shared memory allocation for input values
+  extern __shared__ float s_input[];
 
   float variance = 0.0f;
   int row_offset = blockIdx.x * input_stride;
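For context, a minimal standalone sketch of the pattern this hunk switches to: an extern __shared__ declaration carries no compile-time size, and the buffer's actual size is whatever byte count the host supplies at launch time. The kernel and variable names below (cache_row_kernel, s_row) are illustrative, not taken from the vLLM source.

// Minimal sketch (illustrative, not the vLLM kernel): each block caches one
// row of the input in dynamically sized shared memory before writing it out.
#include <cuda_runtime.h>

__global__ void cache_row_kernel(const float* __restrict__ in,
                                 float* __restrict__ out, int hidden_size) {
  // No size in the declaration: the allocation is whatever byte count the
  // host passes as the third <<<grid, block, bytes, stream>>> argument.
  extern __shared__ float s_row[];

  const int row_offset = blockIdx.x * hidden_size;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    s_row[i] = in[row_offset + i];
  }
  __syncthreads();

  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    out[row_offset + i] = s_row[i];
  }
}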
@@ -162,8 +162,12 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Calculate shared memory size for dynamic allocation
+  size_t shared_mem_bytes = hidden_size * sizeof(float);
+
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shared_mem_bytes, stream>>>(
         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
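On the host side, the pattern is to compute the byte count and pass it as the third launch-configuration argument, as this hunk does. One caveat worth noting, which is not part of this commit: most devices cap dynamic shared memory at 48 KB per block by default, so a hidden_size above roughly 12K floats would additionally need an opt-in via cudaFuncSetAttribute. A hedged host-side sketch, reusing the illustrative cache_row_kernel from the sketch above (assumed defined in the same file):

// Host-side sketch with illustrative names; not the vLLM launch code.
#include <algorithm>
#include <cuda_runtime.h>

void launch_cache_row(const float* in, float* out, int num_tokens,
                      int hidden_size, cudaStream_t stream) {
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));

  // One float of dynamic shared memory per element of the hidden dimension.
  size_t shared_mem_bytes = hidden_size * sizeof(float);

  // Above the 48 KB default, the kernel must opt in to a larger dynamic
  // shared memory carve-out (on devices that have enough shared memory).
  if (shared_mem_bytes > 48 * 1024) {
    cudaFuncSetAttribute(cache_row_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(shared_mem_bytes));
  }

  cache_row_kernel<<<grid, block, shared_mem_bytes, stream>>>(in, out,
                                                              hidden_size);
}

Whether the opt-in branch is ever exercised depends on the largest hidden_size the kernel has to support; the committed change itself only wires the computed size into the launch.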
