 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
+ #include "quantization/vectorization_utils.cuh"

 #ifndef USE_ROCM
 #include <cub/cub.cuh>

 namespace vllm {

- // TODO(woosuk): Further optimize this kernel.
+ // ---------- helpers ----------
+ template <typename T>
+ __device__ __forceinline__ T warp_sum(T v) {
+ #ifdef __HIP_PLATFORM_AMD__
+   const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
+ #else
+   const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
+ #endif
+   // Always reduce over 32 lanes to match downstream logic.
+   constexpr int kWidth = 32;
+   v += __shfl_down_sync(m, v, 16, kWidth);
+   v += __shfl_down_sync(m, v, 8, kWidth);
+   v += __shfl_down_sync(m, v, 4, kWidth);
+   v += __shfl_down_sync(m, v, 2, kWidth);
+   v += __shfl_down_sync(m, v, 1, kWidth);
+   return v;
+ }
+
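A quick way to sanity-check the shuffle reduction above is a throwaway test kernel like the one below (not part of this diff; the kernel name and launch are illustrative). Each lane contributes lane_id + 1, so lane 0 must end up with 1 + 2 + ... + 32 = 528:

// Illustrative only: verifies warp_sum() folds all 32 lanes into lane 0.
__global__ void warp_sum_smoke_test(float* result) {
  float v = static_cast<float>((threadIdx.x & 31) + 1);
  float s = vllm::warp_sum<float>(v);
  if (threadIdx.x == 0) *result = s;  // expected: 528.0f
}
// Launch with a single warp, e.g. warp_sum_smoke_test<<<1, 32>>>(d_result);

Note that with __shfl_down_sync only lane 0 holds the full sum; the kernel below relies on exactly that.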
+ template <typename T>
+ __device__ __forceinline__ bool same_phase(const T* a, const T* b, int widthB) {
+   auto ai = reinterpret_cast<uintptr_t>(a);
+   auto bi = reinterpret_cast<uintptr_t>(b);
+   return ((ai ^ bi) & (widthB - 1)) == 0;
+ }
+
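"Same phase" means both pointers sit at the same byte offset within a widthB-sized window. For example, with widthB = 16, addresses ending in 0x...1008 and 0x...2018 are in phase (both are 8 bytes past a 16-byte boundary), so one scalar prefix aligns both streams at once; addresses ending in 0x...1008 and 0x...2004 are not, and the kernel below then takes the scalar fallback instead of issuing misaligned vector accesses.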
+ // Safe 16B copy to shared: prefix to align, vector main, scalar tail.
+ template <typename T>
+ __device__ __forceinline__ void copy_row_to_shared_aligned(
+     const T* __restrict__ src, T* __restrict__ dst, int n_elems, int tid) {
+   const uintptr_t sa = reinterpret_cast<uintptr_t>(src);
+   const uintptr_t da = reinterpret_cast<uintptr_t>(dst);
+   const bool same16 = (((sa ^ da) & (16 - 1)) == 0);
+
+   if (!same16) {
+     for (int i = tid; i < n_elems; i += blockDim.x) dst[i] = src[i];
+     __syncthreads();
+     return;
+   }
+   const int ebytes = sizeof(T);
+   const int perVec = 16 / ebytes;
+
+   int prefix = 0;
+   const int mis = sa & (16 - 1);
+   if (mis) prefix = (16 - mis) / ebytes;
+   if (prefix > n_elems) prefix = n_elems;
+
+   // scalar prefix
+   for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];
+
+   // vector main
+   const int remain = n_elems - prefix;
+   const int main_elems = (remain / perVec) * perVec;
+   if (main_elems > 0) {
+     const uint4* __restrict__ vsrc =
+         reinterpret_cast<const uint4*>(src + prefix);
+
+ #if defined(__HIP_PLATFORM_AMD__)
+     // ROCm: vector load from global, scalar 32-bit stores to shared
+     uint32_t* __restrict__ s32 =
+         reinterpret_cast<uint32_t*>(dst + prefix);
+     const int nvec = main_elems / perVec;                 // 16B packets
+     constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
+     for (int v = tid; v < nvec; v += blockDim.x) {
+       uint4 p = vsrc[v];
+       const int base = v * WORDS_PER_PKT;
+       s32[base + 0] = p.x;
+       s32[base + 1] = p.y;
+       s32[base + 2] = p.z;
+       s32[base + 3] = p.w;
+     }
+ #else
+     // CUDA: vector load + vector store (fastest)
+     uint4* __restrict__ vdst =
+         reinterpret_cast<uint4*>(dst + prefix);
+     const int nvec = main_elems / perVec;
+     for (int v = tid; v < nvec; v += blockDim.x) {
+       uint4 p = vsrc[v];
+       vdst[v] = p;
+     }
+ #endif
+   }
+
+   // scalar tail
+   const int tail = prefix + main_elems;
+   for (int i = tid + tail; i < n_elems; i += blockDim.x) dst[i] = src[i];
+   __syncthreads();
+ }
+
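The helper must be called by every thread in the block (it ends with __syncthreads()), and the caller is responsible for provisioning enough dynamic shared memory. A minimal usage sketch, with a made-up kernel name and a contiguous-row assumption (the rms_norm_kernel below does exactly this for FP16 rows):

// Illustrative caller: whole block cooperates in the copy, then reuses the row.
template <typename T>
__global__ void row_consumer(const T* __restrict__ in, int hidden_size) {
  extern __shared__ unsigned char smem_raw[];
  T* s_row = reinterpret_cast<T*>(smem_raw);
  const T* in_row = in + blockIdx.x * hidden_size;
  vllm::copy_row_to_shared_aligned(in_row, s_row, hidden_size, threadIdx.x);
  // ... s_row[0..hidden_size) now holds the row in shared memory ...
}
// Host side: pass hidden_size * sizeof(T) as the dynamic shared-memory size, e.g.
// row_consumer<<<num_rows, 256, hidden_size * sizeof(T), stream>>>(d_in, hidden_size);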
+ // ---------------- vec/scalar ops (generic, used for all dtypes) ----------------
+ template <int V, typename T>
+ struct VecMulNormWeight {
+   const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with in/out)
+   float inv_rms;
+   int stride_vec;
+   mutable int64_t vec_idx;
+
+   __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                              const vec_n_t<T, V>& src) const {
+     const vec_n_t<T, V> w = wv[vec_idx];
+ #pragma unroll
+     for (int j = 0; j < V; ++j) {
+       T xn = static_cast<T>(static_cast<float>(src.val[j]) * inv_rms);
+       dst.val[j] = xn * w.val[j];
+     }
+     vec_idx += stride_vec;
+   }
+ };
+
+ template <typename T>
+ struct ScalarMulNormWeight {
+   const T* __restrict__ w_base;  // already offset by +prefix
+   T* __restrict__ out_base;      // out_row + prefix
+   float inv_rms;
+   __device__ __forceinline__ void operator()(T& dst, const T src) const {
+     const int i = static_cast<int>(&dst - out_base);
+     T xn = static_cast<T>(static_cast<float>(src) * inv_rms);
+     dst = xn * w_base[i];
+   }
+ };
+
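These two functors are handed to vectorize_with_alignment from the included vectorization_utils.cuh, which owns the actual traversal. The contract they satisfy can be shown with a hand-rolled equivalent; the device function below is hypothetical, assumes in/out/w are 16-byte aligned, and only illustrates why vec_idx starts at threadIdx.x and advances by blockDim.x per packet:

// Illustrative only: the real loop structure lives in vectorize_with_alignment.
template <int V, typename T>
__device__ void apply_norm_row_manual(const T* in, T* out, const T* w,
                                      int n, float inv_rms) {
  using VecT = vec_n_t<T, V>;
  vllm::VecMulNormWeight<V, T> vec_op{reinterpret_cast<const VecT*>(w), inv_rms,
                                      (int)blockDim.x, (int64_t)threadIdx.x};
  vllm::ScalarMulNormWeight<T> sca_op{w, out, inv_rms};
  const VecT* vin = reinterpret_cast<const VecT*>(in);
  VecT* vout = reinterpret_cast<VecT*>(out);
  // One 16B packet per call; vec_op advances its own weight index in lockstep.
  for (int p = threadIdx.x; p < n / V; p += blockDim.x) {
    vec_op(vout[p], vin[p]);
  }
  // Leftover elements go through the scalar functor, which recovers the index
  // from the destination address (&dst - out_base).
  for (int i = (n / V) * V + threadIdx.x; i < n; i += blockDim.x) {
    sca_op(out[i], in[i]);
  }
}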
+ // ---------- kernel ----------
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
     const int64_t input_stride,
-     const scalar_t* __restrict__ weight,  // [hidden_size]
-     const float epsilon, const int num_tokens, const int hidden_size) {
-   __shared__ float s_variance;
-   float variance = 0.0f;
+     const scalar_t* __restrict__ weight,  // [hidden_size]
+     const float epsilon, const int /*num_tokens*/, const int hidden_size,
+     int smem_elems) {

-   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-     const float x = (float)input[blockIdx.x * input_stride + idx];
-     variance += x * x;
+   const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+   scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
+
+   // Optional cached-row (half) when host provisioned shmem
+   extern __shared__ unsigned char smem_raw[];
+   scalar_t* s_in = reinterpret_cast<scalar_t*>(smem_raw);
+
+ #ifdef __HIP_PLATFORM_AMD__
+   constexpr bool kAllowCache = false;
+ #else
+   constexpr bool kAllowCache = true;
+ #endif
+   const bool use_cached =
+       kAllowCache && (sizeof(scalar_t) == 2) && (smem_elems > 0);
+
+ #if !defined(__HIP_PLATFORM_AMD__)
+   if (use_cached) {
+     copy_row_to_shared_aligned(in_row, s_in, hidden_size, threadIdx.x);
   }
+ #endif

-   using BlockReduce = cub::BlockReduce<float, 1024>;
-   __shared__ typename BlockReduce::TempStorage reduceStore;
-   variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+   // -------- Pass 1: sum of squares --------
+   using acc_t = float;
+   acc_t sumsq = acc_t(0);
+   if (use_cached) {
+     for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+       acc_t x = static_cast<acc_t>(s_in[i]);
+       sumsq += x * x;
+     }
+   } else {
+     for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+       acc_t x = static_cast<acc_t>(in_row[i]);
+       sumsq += x * x;
+     }
+   }

-   if (threadIdx.x == 0) {
-     s_variance = rsqrtf(variance / hidden_size + epsilon);
+   // warp + block reduction in acc_t
+   acc_t wsum = warp_sum<acc_t>(sumsq);
+   __shared__ acc_t warp_sums_sh[32];
+   if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
+   __syncthreads();
+
+   acc_t total = acc_t(0);
+   if (threadIdx.x < 32) {
+     acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x] : acc_t(0);
+     total = warp_sum<acc_t>(v);
+     if (threadIdx.x == 0) warp_sums_sh[0] = total;
   }
   __syncthreads();

-   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-     float x = (float)input[blockIdx.x * input_stride + idx];
-     out[blockIdx.x * hidden_size + idx] =
-         ((scalar_t)(x * s_variance)) * weight[idx];
+   // compute inv_rms in float to match baseline epsilon semantics
+   const float inv_rms =
+       rsqrtf(static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) + epsilon);
+
+   // -------- Fast path: HS == blockDim.x (e.g., 1024) --------
+   if (hidden_size == blockDim.x) {
+     int i = threadIdx.x;
+     float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+     scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+     out_row[i] = xn * weight[i];
+     return;
+   }
+
+   // -------- Pass 2: Vectorize when phases align --------
+   constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B packets
+   constexpr int WIDTH = V * sizeof(scalar_t);
+
+   const bool can_vec =
+       (hidden_size % V == 0) &&
+       same_phase(in_row, out_row, WIDTH) &&
+       same_phase(in_row, weight, WIDTH);
+
+   if (can_vec) {
+     const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
+     const int mis = addr & (WIDTH - 1);
+     const int prefix = mis ? (WIDTH - mis) / (int)sizeof(scalar_t) : 0;
+
+     // scalar prefix
+     for (int i = threadIdx.x; i < prefix; i += blockDim.x) {
+       float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+       out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+     }
+
+     // vector main
+     const int remain = hidden_size - prefix;
+     const int main_len = (remain / V) * V;
+     if (main_len > 0) {
+       using VecT = vec_n_t<scalar_t, V>;
+       const VecT* __restrict__ wv =
+           reinterpret_cast<const VecT*>(weight + prefix);
+
+       VecMulNormWeight<V, scalar_t> vec_op{
+           /*wv=*/wv,
+           /*inv_rms=*/inv_rms,
+           /*stride_vec=*/(int)blockDim.x,
+           /*vec_idx=*/(int64_t)threadIdx.x
+       };
+       ScalarMulNormWeight<scalar_t> sca_op{
+           /*w_base=*/weight + prefix,
+           /*out_base=*/out_row + prefix,
+           /*inv_rms=*/inv_rms
+       };
+
+       const scalar_t* vin = use_cached ? (s_in + prefix) : (in_row + prefix);
+       vectorize_with_alignment<V>(
+           vin, out_row + prefix, main_len,
+           threadIdx.x, blockDim.x, vec_op, sca_op);
+     }
+
+     // scalar tail
+     const int tail = prefix + ((remain / V) * V);
+     for (int i = threadIdx.x + tail; i < hidden_size; i += blockDim.x) {
+       float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+       out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+     }
+     return;
+   }
+
+   // -------- Fallback scalar --------
+   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+     float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+     scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+     out_row[i] = xn * weight[i];
   }
 }

+
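Whatever path the kernel takes, the per-row math is the same, so a plain CPU reference is enough for checking its output. This is a float-only testing sketch, not part of the diff; for FP16/BF16 the kernel additionally rounds x * inv_rms to the storage dtype before multiplying by the weight, so comparisons should allow a small tolerance:

// CPU reference for one row: out[i] = (x[i] * inv_rms) * w[i],
// inv_rms = 1 / sqrt(mean(x^2) + epsilon), accumulated in float as in the kernel.
#include <cmath>

void rms_norm_row_ref(const float* in, const float* w, float* out,
                      int hidden_size, float epsilon) {
  float sumsq = 0.f;
  for (int i = 0; i < hidden_size; ++i) sumsq += in[i] * in[i];
  const float inv_rms = 1.f / std::sqrt(sumsq / hidden_size + epsilon);
  for (int i = 0; i < hidden_size; ++i) out[i] = (in[i] * inv_rms) * w[i];
}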
 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
@@ -142,6 +368,13 @@ fused_add_rms_norm_kernel(

 }  // namespace vllm

+ static inline int ln_block_threads_unified(int H) {
+   int threads = (H >= 1024) ? 256
+               : (H >= 512)  ? 512
+                             : std::min(1024, ((H + 31) / 32) * 32);
+   return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
+ }
+
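Worked examples of the block-size heuristic: hidden_size = 4096 gives 256 threads, hidden_size = 768 gives 512, and hidden_size = 100 gives 128 (100 rounds up to 128, which is also the lower clamp). The result is always a multiple of 32 between 128 and 1024; note that large rows deliberately get fewer threads (256) than mid-sized rows (512), trading block width for more work per thread.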
145
378
void rms_norm (torch::Tensor& out, // [..., hidden_size]
146
379
torch::Tensor& input, // [..., hidden_size]
147
380
torch::Tensor& weight, // [hidden_size]
@@ -150,21 +383,39 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
150
383
TORCH_CHECK (input.stride (-1 ) == 1 );
151
384
TORCH_CHECK (weight.is_contiguous ());
152
385
153
- int hidden_size = input.size (-1 );
154
- int num_tokens = input.numel () / hidden_size;
155
- int64_t input_stride = input.stride (-2 );
386
+ const int hidden_size = input.size (-1 );
387
+ const int num_tokens = input.numel () / hidden_size;
388
+ const int64_t in_stride = input.stride (-2 );
156
389
157
390
dim3 grid (num_tokens);
158
- dim3 block (std::min (hidden_size, 1024 ));
391
+ dim3 block (ln_block_threads_unified (hidden_size));
392
+
159
393
const at::cuda::OptionalCUDAGuard device_guard (device_of (input));
160
394
const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
395
+
396
+ // Optional cached-row for FP16 (recommended). Kernel still works if this is 0.
397
+ size_t shmem_bytes = 0 ;
398
+ int smem_elems = 0 ;
399
+ if (input.scalar_type () == at::kHalf && hidden_size <= 4096 ) {
400
+ shmem_bytes = static_cast <size_t >(hidden_size) * sizeof (at::Half);
401
+ smem_elems = hidden_size; // flag to kernel that shmem was provisioned
402
+ }
403
+
161
404
VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " rms_norm_kernel" , [&] {
162
- vllm::rms_norm_kernel<scalar_t ><<<grid, block, 0 , stream>>> (
163
- out.data_ptr <scalar_t >(), input.data_ptr <scalar_t >(), input_stride,
164
- weight.data_ptr <scalar_t >(), epsilon, num_tokens, hidden_size);
405
+ vllm::rms_norm_kernel<scalar_t >
406
+ <<<grid, block, shmem_bytes, stream>>> (
407
+ out.data_ptr <scalar_t >(),
408
+ input.data_ptr <scalar_t >(),
409
+ in_stride,
410
+ weight.data_ptr <scalar_t >(),
411
+ static_cast <float >(epsilon),
412
+ num_tokens,
413
+ hidden_size,
414
+ smem_elems);
165
415
});
166
416
}
167
417
418
+
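For reference, a minimal host-side sketch of exercising this entry point directly from C++ (shapes and the 1e-6 epsilon are illustrative; in vLLM the function is normally reached through the registered custom op rather than called directly). With FP16 and hidden_size <= 4096 this also takes the cached-row shared-memory path provisioned above:

// Illustrative call: a batch of 8 contiguous rows of 4096 FP16 elements.
torch::Tensor input  = torch::randn({8, 4096},
    torch::dtype(torch::kHalf).device(torch::kCUDA));
torch::Tensor weight = torch::ones({4096},
    torch::dtype(torch::kHalf).device(torch::kCUDA));
torch::Tensor out    = torch::empty_like(input);
rms_norm(out, input, weight, 1e-6);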
 #define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
   VLLM_DISPATCH_FLOATING_TYPES(                                              \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                \