File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change 12
12
13
13
namespace vllm {
14
14
15
- // TODO(woosuk): Further optimize this kernel.
16
15
template <typename scalar_t >
17
16
__global__ void rms_norm_kernel (
18
17
scalar_t * __restrict__ out, // [..., hidden_size]
@@ -21,10 +20,14 @@ __global__ void rms_norm_kernel(
21
20
const scalar_t * __restrict__ weight, // [hidden_size]
22
21
const float epsilon, const int num_tokens, const int hidden_size) {
23
22
__shared__ float s_variance;
23
+ __shared__ float s_input[1024 ];
24
+
24
25
float variance = 0 .0f ;
26
+ int row_offset = blockIdx .x * input_stride;
25
27
26
28
for (int idx = threadIdx .x ; idx < hidden_size; idx += blockDim .x ) {
27
- const float x = (float )input[blockIdx .x * input_stride + idx];
29
+ float x = (float )input[row_offset + idx];
30
+ s_input[idx] = x;
28
31
variance += x * x;
29
32
}
30
33
@@ -38,7 +41,7 @@ __global__ void rms_norm_kernel(
38
41
__syncthreads ();
39
42
40
43
for (int idx = threadIdx .x ; idx < hidden_size; idx += blockDim .x ) {
41
- float x = ( float )input[ blockIdx . x * input_stride + idx];
44
+ float x = s_input[ idx];
42
45
out[blockIdx .x * hidden_size + idx] =
43
46
((scalar_t )(x * s_variance)) * weight[idx];
44
47
}
You can’t perform that action at this time.
0 commit comments