@@ -20,8 +20,8 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
-  // Assumes hidden_size <= 1024, enforced via launch config
-  __shared__ float s_input[1024];
+  // Dynamic shared memory allocation for input values
+  extern __shared__ float s_input[];
 
   float variance = 0.0f;
   int row_offset = blockIdx.x * input_stride;
@@ -162,8 +162,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Calculate shared memory size for dynamic allocation
+  size_t shared_mem_bytes = hidden_size * sizeof(float);
+
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shared_mem_bytes, stream>>>(
         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
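
For reference, a minimal standalone sketch of the dynamic shared-memory pattern this diff adopts: the kernel declares an unsized extern __shared__ array and the host passes the byte count as the third launch parameter, which replaces the old fixed 1024-float buffer (and its hidden_size <= 1024 assumption) with a buffer sized to hidden_size at launch. The kernel and launcher names here (copy_row_kernel, launch_copy) are hypothetical, not part of the PR; the 48 KB opt-in threshold is the default per-block dynamic shared-memory limit on most GPUs.

#include <algorithm>
#include <cuda_runtime.h>

// Hypothetical kernel: stages one row of input in dynamically sized shared
// memory and writes it back out, mirroring the launch pattern in the diff.
__global__ void copy_row_kernel(float* out, const float* in, int hidden_size) {
  extern __shared__ float s_buf[];  // size supplied at launch time
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    s_buf[i] = in[blockIdx.x * hidden_size + i];
  }
  __syncthreads();
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    out[blockIdx.x * hidden_size + i] = s_buf[i];
  }
}

void launch_copy(float* out, const float* in, int num_tokens, int hidden_size,
                 cudaStream_t stream) {
  size_t shared_mem_bytes = hidden_size * sizeof(float);
  // Dynamic requests above the default per-block limit (48 KB on most GPUs)
  // must be opted into explicitly before the launch.
  if (shared_mem_bytes > 48 * 1024) {
    cudaFuncSetAttribute(copy_row_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(shared_mem_bytes));
  }
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  // Third launch parameter is the dynamic shared-memory size in bytes.
  copy_row_kernel<<<grid, block, shared_mem_bytes, stream>>>(out, in,
                                                             hidden_size);
}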