@@ -19,15 +19,15 @@ __device__ __forceinline__ T warp_sum(T v) {
 #ifdef __HIP_PLATFORM_AMD__
   const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
 #else
-  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
+  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
 #endif
   // Always reduce over 32 lanes to match downstream logic.
   constexpr int kWidth = 32;
   v += __shfl_down_sync(m, v, 16, kWidth);
-  v += __shfl_down_sync(m, v, 8, kWidth);
-  v += __shfl_down_sync(m, v, 4, kWidth);
-  v += __shfl_down_sync(m, v, 2, kWidth);
-  v += __shfl_down_sync(m, v, 1, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
   return v;
 }

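For context on the hunk above: warp_sum is a 32-lane tree reduction built on __shfl_down_sync, and the mask type differs only because AMD wavefronts are 64 lanes wide, so HIP expects a 64-bit mask. A minimal usage sketch (hypothetical demo kernel, not part of this patch, assuming the helper lives in namespace vllm as the later hunks suggest):

// Hypothetical demo: one warp of 32 threads sums 32 floats.
__global__ void warp_sum_demo(const float* __restrict__ in,
                              float* __restrict__ out) {
  float v = in[threadIdx.x];       // launched with exactly 32 threads
  v = vllm::warp_sum<float>(v);    // shuffle-down tree reduction over the warp
  if (threadIdx.x == 0) *out = v;  // only lane 0 holds the full sum
}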
@@ -63,18 +63,17 @@ __device__ __forceinline__ void copy_row_to_shared_aligned(
   for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];

   // vector main
-  const int remain = n_elems - prefix;
+  const int remain = n_elems - prefix;
   const int main_elems = (remain / perVec) * perVec;
   if (main_elems > 0) {
     const uint4* __restrict__ vsrc =
         reinterpret_cast<const uint4*>(src + prefix);

 #if defined(__HIP_PLATFORM_AMD__)
     // ROCm: vector load from global, scalar 32-bit stores to shared
-    uint32_t* __restrict__ s32 =
-        reinterpret_cast<uint32_t*>(dst + prefix);
-    const int nvec = main_elems / perVec;  // 16B packets
-    constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
+    uint32_t* __restrict__ s32 = reinterpret_cast<uint32_t*>(dst + prefix);
+    const int nvec = main_elems / perVec;  // 16B packets
+    constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
     for (int v = tid; v < nvec; v += blockDim.x) {
       uint4 p = vsrc[v];
       const int base = v * WORDS_PER_PKT;
@@ -85,8 +84,7 @@ __device__ __forceinline__ void copy_row_to_shared_aligned(
     }
 #else
     // CUDA: vector load + vector store (fastest)
-    uint4* __restrict__ vdst =
-        reinterpret_cast<uint4*>(dst + prefix);
+    uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst + prefix);
     const int nvec = main_elems / perVec;
     for (int v = tid; v < nvec; v += blockDim.x) {
       uint4 p = vsrc[v];
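The two hunks above only re-wrap lines in the global-to-shared copy helper: after a scalar prefix reaches 16-byte alignment, the bulk of the row moves as uint4 (16 B) packets, with scalar 32-bit stores into shared memory on ROCm and full vector stores on CUDA. A stripped-down sketch of just the vector main loop for float data, assuming both pointers are already 16-byte aligned and n is a multiple of 4 (prefix and tail handling omitted):

// Minimal sketch, not the helper itself: grid-stride over 16B packets.
__device__ void copy_row_vec4(const float* __restrict__ src,
                              float* __restrict__ dst, int n) {
  const uint4* __restrict__ vsrc = reinterpret_cast<const uint4*>(src);
  uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst);
  const int nvec = n / 4;  // four 32-bit floats per 16B packet
  for (int v = threadIdx.x; v < nvec; v += blockDim.x) {
    vdst[v] = vsrc[v];  // one 128-bit load + one 128-bit store per packet
  }
}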
@@ -101,12 +99,14 @@ __device__ __forceinline__ void copy_row_to_shared_aligned(
   __syncthreads();
 }

-// ---------------- vec/scalar ops (generic, used for all dtypes) ----------------
+// ---------------- vec/scalar ops (generic, used for all dtypes)
+// ----------------
 template <int V, typename T>
 struct VecMulNormWeight {
-  const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with in/out)
+  const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with
+                                         // in/out)
   float inv_rms;
-  int stride_vec;
+  int stride_vec;
   mutable int64_t vec_idx;

   __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
@@ -123,8 +123,8 @@ struct VecMulNormWeight {

 template <typename T>
 struct ScalarMulNormWeight {
-  const T* __restrict__ w_base;  // already offset by +prefix
-  T* __restrict__ out_base;      // out_row + prefix
+  const T* __restrict__ w_base;  // already offset by +prefix
+  T* __restrict__ out_base;      // out_row + prefix
   float inv_rms;
   __device__ __forceinline__ void operator()(T& dst, const T src) const {
     const int i = static_cast<int>(&dst - out_base);
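VecMulNormWeight scales a whole 16-byte packet and multiplies by the matching weight packet, while ScalarMulNormWeight mops up leftover elements one at a time, recovering each index from the destination address. A generic sketch of how such a vector/scalar op pair can be driven (this is not vLLM's vectorize_with_alignment; vec_n_t is assumed from the surrounding code and alignment is assumed to be handled already):

// Hypothetical driver: vector op over whole packets, scalar op over the tail.
template <int V, typename T, typename VecOp, typename ScaOp>
__device__ void apply_vec_then_scalar(const T* __restrict__ in,
                                      T* __restrict__ out, int n, int tid,
                                      int nthreads, VecOp& vec_op,
                                      ScaOp& sca_op) {
  using VecT = vec_n_t<T, V>;  // packet type, assumed from this file
  const int nvec = n / V;
  const VecT* vin = reinterpret_cast<const VecT*>(in);
  VecT* vout = reinterpret_cast<VecT*>(out);
  for (int i = tid; i < nvec; i += nthreads) vec_op(vout[i], vin[i]);
  for (int i = nvec * V + tid; i < n; i += nthreads) sca_op(out[i], in[i]);
}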
@@ -139,12 +139,11 @@ __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,   // [..., hidden_size]
     const int64_t input_stride,
-    const scalar_t* __restrict__ weight,  // [hidden_size]
+    const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int /*num_tokens*/, const int hidden_size,
     int smem_elems) {
-
-  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
-  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
+  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;

   // Optional cached-row (half) when host provisioned shmem
   extern __shared__ unsigned char smem_raw[];
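Each block handles one token row (blockIdx.x), and the kernel receives an optional row cache through dynamic shared memory. The code that follows this hunk is not shown; presumably it reinterprets the raw byte array and guards on the host-provided element count, roughly along these lines (a sketch under that assumption, using the s_in / use_cached names that appear in a later hunk):

// Hypothetical continuation: treat the raw bytes as a half-precision row cache
// only when the host actually provisioned it and the dtype is 2 bytes wide.
__half* s_in = reinterpret_cast<__half*>(smem_raw);
const bool use_cached = (smem_elems >= hidden_size) && (sizeof(scalar_t) == 2);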
@@ -187,15 +186,17 @@ __global__ void rms_norm_kernel(

   acc_t total = acc_t(0);
   if (threadIdx.x < 32) {
-    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x] : acc_t(0);
+    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x]
+                                                     : acc_t(0);
     total = warp_sum<acc_t>(v);
     if (threadIdx.x == 0) warp_sums_sh[0] = total;
   }
   __syncthreads();

   // compute inv_rms in float to match baseline epsilon semantics
-  const float inv_rms =
-      rsqrtf(static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) + epsilon);
+  const float inv_rms = rsqrtf(
+      static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) +
+      epsilon);

   // -------- Fast path: HS == blockDim.x (e.g., 1024) --------
   if (hidden_size == blockDim.x) {
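The reduction above collects per-warp partial sums in warp_sums_sh, lets the first warp fold them with warp_sum, and then converts the mean square to inv_rms in float so epsilon is added inside the rsqrt exactly as in the baseline. A hypothetical host-side reference for validating one row (plain float for clarity; needs <cmath>):

// Reference RMSNorm for a single row: out[i] = in[i] * inv_rms * w[i],
// with inv_rms = 1 / sqrt(mean(in[i]^2) + eps).
void rms_norm_row_ref(float* out, const float* in, const float* w, int H,
                      float eps) {
  double ss = 0.0;
  for (int i = 0; i < H; ++i) ss += static_cast<double>(in[i]) * in[i];
  const float inv_rms = 1.0f / std::sqrt(static_cast<float>(ss / H) + eps);
  for (int i = 0; i < H; ++i) out[i] = in[i] * inv_rms * w[i];
}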
@@ -210,10 +211,9 @@ __global__ void rms_norm_kernel(
   constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B packets
   constexpr int WIDTH = V * sizeof(scalar_t);

-  const bool can_vec =
-      (hidden_size % V == 0) &&
-      same_phase(in_row, out_row, WIDTH) &&
-      same_phase(in_row, weight, WIDTH);
+  const bool can_vec = (hidden_size % V == 0) &&
+                       same_phase(in_row, out_row, WIDTH) &&
+                       same_phase(in_row, weight, WIDTH);

   if (can_vec) {
     const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
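can_vec requires the row length to be a multiple of the packet width and the input, output, and weight pointers to sit at the same offset within a 16-byte window, so a single alignment prefix serves all three streams. same_phase itself is not shown in this diff; a sketch of what such a predicate presumably looks like (hypothetical name and implementation):

// Assumed semantics: two pointers are "in phase" when they share the same
// byte offset modulo the vector width.
__device__ __forceinline__ bool same_phase_sketch(const void* a, const void* b,
                                                  int width) {
  const uintptr_t pa = reinterpret_cast<uintptr_t>(a);
  const uintptr_t pb = reinterpret_cast<uintptr_t>(b);
  return (pa % width) == (pb % width);
}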
@@ -227,29 +227,24 @@ __global__ void rms_norm_kernel(
     }

     // vector main
-    const int remain = hidden_size - prefix;
+    const int remain = hidden_size - prefix;
     const int main_len = (remain / V) * V;
     if (main_len > 0) {
       using VecT = vec_n_t<scalar_t, V>;
       const VecT* __restrict__ wv =
           reinterpret_cast<const VecT*>(weight + prefix);

-      VecMulNormWeight<V, scalar_t> vec_op{
-          /*wv=*/wv,
-          /*inv_rms=*/inv_rms,
-          /*stride_vec=*/(int)blockDim.x,
-          /*vec_idx=*/(int64_t)threadIdx.x
-      };
-      ScalarMulNormWeight<scalar_t> sca_op{
-          /*w_base=*/weight + prefix,
-          /*out_base=*/out_row + prefix,
-          /*inv_rms=*/inv_rms
-      };
+      VecMulNormWeight<V, scalar_t> vec_op{/*wv=*/wv,
+                                           /*inv_rms=*/inv_rms,
+                                           /*stride_vec=*/(int)blockDim.x,
+                                           /*vec_idx=*/(int64_t)threadIdx.x};
+      ScalarMulNormWeight<scalar_t> sca_op{/*w_base=*/weight + prefix,
+                                           /*out_base=*/out_row + prefix,
+                                           /*inv_rms=*/inv_rms};

       const scalar_t* vin = use_cached ? (s_in + prefix) : (in_row + prefix);
-      vectorize_with_alignment<V>(
-          vin, out_row + prefix, main_len,
-          threadIdx.x, blockDim.x, vec_op, sca_op);
+      vectorize_with_alignment<V>(vin, out_row + prefix, main_len, threadIdx.x,
+                                  blockDim.x, vec_op, sca_op);
     }

     // scalar tail
@@ -269,7 +264,6 @@ __global__ void rms_norm_kernel(
   }
 }

-
 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
@@ -369,9 +363,9 @@ fused_add_rms_norm_kernel(
 }  // namespace vllm

 static inline int ln_block_threads_unified(int H) {
-  int threads = (H >= 1024) ? 256
-              : (H >= 512)  ? 512
-                            : std::min(1024, ((H + 31) / 32) * 32);
+  int threads = (H >= 1024) ? 256
+                : (H >= 512) ? 512
+                             : std::min(1024, ((H + 31) / 32) * 32);
   return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
 }

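The hunk above only re-indents the ternary chain; the heuristic is unchanged: large rows get 256 threads (more resident blocks), medium rows 512, small rows roughly one thread per element rounded up to a warp, and the final result is clamped to [128, 1024] on a warp multiple. Worked examples, hand-computed from the code above:

// ln_block_threads_unified(H):
//   H = 2048 -> 256                          (H >= 1024 branch)
//   H = 1024 -> 256
//   H =  768 -> 512                          (H >= 512 branch)
//   H =  300 -> ((300 + 31) / 32) * 32 = 320 -> stays 320 after the clamp
//   H =   64 -> ((64 + 31) / 32) * 32  =  64 -> raised to the 128 floor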
@@ -383,8 +377,8 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());

-  const int hidden_size = input.size(-1);
-  const int num_tokens = input.numel() / hidden_size;
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
   const int64_t in_stride = input.stride(-2);

   dim3 grid(num_tokens);
@@ -393,29 +387,23 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  // Optional cached-row for FP16 (recommended). Kernel still works if this is 0.
+  // Optional cached-row for FP16 (recommended). Kernel still works if this is
+  // 0.
   size_t shmem_bytes = 0;
   int smem_elems = 0;
   if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
     shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
-    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
+    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
   }

   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t>
-        <<<grid, block, shmem_bytes, stream>>>(
-            out.data_ptr<scalar_t>(),
-            input.data_ptr<scalar_t>(),
-            in_stride,
-            weight.data_ptr<scalar_t>(),
-            static_cast<float>(epsilon),
-            num_tokens,
-            hidden_size,
-            smem_elems);
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shmem_bytes, stream>>>(
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), in_stride,
+        weight.data_ptr<scalar_t>(), static_cast<float>(epsilon), num_tokens,
+        hidden_size, smem_elems);
   });
 }

-
 #define LAUNCH_FUSED_ADD_RMS_NORM(width)                              \
   VLLM_DISPATCH_FLOATING_TYPES(                                       \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] {         \
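For completeness, a hypothetical C++ call site for the rms_norm launcher above, assuming the usual (out, input, weight, epsilon) signature and a 2-D input so that input.stride(-2) is valid; shapes and epsilon are illustrative only:

// Not part of the patch: allocate FP16 tensors on the GPU and normalize them.
const int64_t num_tokens = 8, hidden_size = 4096;
torch::Tensor input = torch::randn(
    {num_tokens, hidden_size},
    torch::dtype(torch::kHalf).device(torch::kCUDA));
torch::Tensor weight = torch::ones({hidden_size}, input.options());
torch::Tensor out = torch::empty_like(input);
rms_norm(out, input, weight, /*epsilon=*/1e-6);  // one block per token row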