File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change 12
12
13
13
namespace vllm {
14
14
15
- // TODO(woosuk): Further optimize this kernel.
16
15
template <typename scalar_t >
17
16
__global__ void rms_norm_kernel (
18
17
scalar_t * __restrict__ out, // [..., hidden_size]
@@ -21,10 +20,15 @@ __global__ void rms_norm_kernel(
21
20
const scalar_t * __restrict__ weight, // [hidden_size]
22
21
const float epsilon, const int num_tokens, const int hidden_size) {
23
22
__shared__ float s_variance;
23
+ // Assumes hidden_size <= 1024, enforced via launch config
24
+ __shared__ float s_input[1024 ];
25
+
24
26
float variance = 0 .0f ;
27
+ int row_offset = blockIdx .x * input_stride;
25
28
26
29
for (int idx = threadIdx .x ; idx < hidden_size; idx += blockDim .x ) {
27
- const float x = (float )input[blockIdx .x * input_stride + idx];
30
+ float x = (float )input[row_offset + idx];
31
+ s_input[idx] = x;
28
32
variance += x * x;
29
33
}
30
34
@@ -38,7 +42,7 @@ __global__ void rms_norm_kernel(
38
42
__syncthreads ();
39
43
40
44
for (int idx = threadIdx .x ; idx < hidden_size; idx += blockDim .x ) {
41
- float x = ( float )input[ blockIdx . x * input_stride + idx];
45
+ float x = s_input[ idx];
42
46
out[blockIdx .x * hidden_size + idx] =
43
47
((scalar_t )(x * s_variance)) * weight[idx];
44
48
}
You can’t perform that action at this time.
0 commit comments