@@ -34,6 +34,36 @@ __global__ void rms_norm_kernel(
   }
 }
 
+// TODO: Further optimize this kernel.
+template<typename scalar_t>
+__global__ void fused_add_rms_norm_kernel(
+  scalar_t* __restrict__ input,         // [..., hidden_size]
+  scalar_t* __restrict__ residual,      // [..., hidden_size]
+  const scalar_t* __restrict__ weight,  // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    x += (float) residual[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) residual[blockIdx.x * hidden_size + idx];
+    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
 } // namespace vllm
 
 void rms_norm(
@@ -60,3 +90,28 @@ void rms_norm(
       hidden_size);
     });
 }
+
+void fused_add_rms_norm(
+  torch::Tensor& input,    // [..., hidden_size]
+  torch::Tensor& residual, // [..., hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int hidden_size = input.size(-1);
+  int num_tokens = input.numel() / hidden_size;
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "fused_add_rms_norm_kernel",
+    [&] {
+      vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        input.data_ptr<scalar_t>(),
+        residual.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
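
Note: per token, the added kernel computes out = (input + residual) * rsqrt(mean((input + residual)^2) + epsilon) * weight, writing the pre-normalization sum back into residual and the normalized result into input. A minimal host-side reference of the same math (a hypothetical helper, not part of this commit) that could be used to check the kernel's in-place semantics in a test:

// Hypothetical CPU reference for fused_add_rms_norm_kernel (not part of this
// commit): mirrors the kernel's in-place behavior so device output can be
// compared against it in a unit test.
#include <cmath>
#include <vector>

void fused_add_rms_norm_ref(
    std::vector<float>& input,         // [num_tokens * hidden_size], overwritten with the normalized output
    std::vector<float>& residual,      // [num_tokens * hidden_size], overwritten with input + residual
    const std::vector<float>& weight,  // [hidden_size]
    float epsilon,
    int num_tokens,
    int hidden_size) {
  for (int t = 0; t < num_tokens; ++t) {
    float variance = 0.0f;
    // First pass: accumulate the sum of squares of (input + residual)
    // and store the sum back into residual, as the kernel does.
    for (int i = 0; i < hidden_size; ++i) {
      float x = input[t * hidden_size + i] + residual[t * hidden_size + i];
      variance += x * x;
      residual[t * hidden_size + i] = x;
    }
    const float inv_rms = 1.0f / std::sqrt(variance / hidden_size + epsilon);
    // Second pass: scale by the inverse RMS and the learned weight.
    for (int i = 0; i < hidden_size; ++i) {
      input[t * hidden_size + i] = residual[t * hidden_size + i] * inv_rms * weight[i];
    }
  }
}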
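To be callable from Python, the new fused_add_rms_norm entry point would also need a binding, which this diff does not show. A minimal sketch under the usual pybind11-based PyTorch extension setup (assumed, not confirmed by this commit):

// Hypothetical binding sketch (not shown in this diff): exposes the new
// host function through a standard PyTorch extension module.
#include <torch/extension.h>

void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "fused_add_rms_norm",
    &fused_add_rms_norm,
    "In-place fused residual add and RMS normalization (CUDA)");
}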