
Commit 6942388

bbeckca authored and facebook-github-bot committed
Vectorize RMSNorm CUDA kernel
Summary:

What: Make RMSNorm faster by reading data in bigger aligned chunks, caching fp16 rows in shared memory, making the FP8-quant version work with strided inputs, and using the same launch settings as the unfused path.

Why: Cut global memory traffic using aligned vector inputs/outputs and shared-mem reuse (avoids a second read), make the FP8 path safe for strided inputs, and preserve numerics by matching the unfused reduction/launch order.

Test Plan:

1) Run tests

```
[[email protected] /data/users/benjibeck/fbsource/fbcode/vllm (1043a27694)]$ buck2 test :test_kernels_layernorm
Buck UI: https://www.internalfb.com/buck2/054ebad3-ad92-4676-a4d2-3bf43e44f31a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10414574240710255
Network: Up: 152MiB  Down: 2.9GiB  (reSessionID-14af330c-26bf-41d5-87b0-5775bf7d6f8a)
Loading targets.   Remaining 0/7    150 dirs read, 69 targets declared
Analyzing targets. Remaining 0/32   772 actions, 819 artifacts declared
Executing actions. Remaining 0/245  48.3s exec time total
Command: test.     Finished 1 local, 14 remote, 131 cache (90% hit)  45.2s exec time cached (93%)
Time elapsed: 4:53.4s
Tests finished: Pass 3169. Fail 0. Fatal 0. Skip 0. Build failure 0
```

2) Run benchmark

```
buck run :benchmark_layernorm -- --num-tokens 16384 --hidden-size 1024 --dtype half --num-iters 500
Before -> Kernel running time: 105.918 us
After  -> Kernel running time:  42.571 us
```

Rollback Plan:

Differential Revision: D79969610
1 parent 68b254d commit 6942388
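
For orientation, the numerics this change claims to preserve are plain RMSNorm: each row is scaled by `rsqrt(mean(x^2) + eps)` and then multiplied elementwise by the weight vector. The sketch below is a deliberately naive scalar CUDA kernel that only illustrates that definition; it is not the patched vLLM kernel, and the kernel name, float-only buffers, shared-memory `atomicAdd` reduction, and the tiny `main` harness are illustrative assumptions.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Naive RMSNorm reference: one block per token, plain float, no vectorization.
__global__ void rms_norm_reference(float* out, const float* in,
                                   const float* weight, float eps, int hidden) {
  const float* row = in + blockIdx.x * hidden;
  float* out_row = out + blockIdx.x * hidden;

  // Block-wide sum of squares via a shared accumulator (slow but obvious).
  __shared__ float ssq;
  if (threadIdx.x == 0) ssq = 0.f;
  __syncthreads();

  float local = 0.f;
  for (int i = threadIdx.x; i < hidden; i += blockDim.x) {
    float x = row[i];
    local += x * x;
  }
  atomicAdd(&ssq, local);
  __syncthreads();

  const float inv_rms = rsqrtf(ssq / hidden + eps);
  for (int i = threadIdx.x; i < hidden; i += blockDim.x) {
    out_row[i] = row[i] * inv_rms * weight[i];
  }
}

int main() {
  const int tokens = 2, hidden = 8;
  float *in, *w, *out;
  cudaMallocManaged(&in, tokens * hidden * sizeof(float));
  cudaMallocManaged(&w, hidden * sizeof(float));
  cudaMallocManaged(&out, tokens * hidden * sizeof(float));
  for (int i = 0; i < tokens * hidden; ++i) in[i] = 1.0f + i;
  for (int i = 0; i < hidden; ++i) w[i] = 1.0f;

  rms_norm_reference<<<tokens, 128>>>(out, in, w, 1e-6f, hidden);
  cudaDeviceSynchronize();
  std::printf("out[0] = %f\n", out[0]);

  cudaFree(in);
  cudaFree(w);
  cudaFree(out);
  return 0;
}
```

The optimized kernel in the diff below computes the same quantity; the speedup comes from how each row is read (16-byte vectors, optional fp16 row cached in shared memory) and reduced (warp shuffles), not from changing the math.
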

3 files changed (+378 −73 lines)


benchmarks/kernels/benchmark_layernorm.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -11,7 +11,7 @@
 
 
 @torch.inference_mode()
-def main(
+def run_benchmark(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
@@ -59,7 +59,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     print(f"Kernel running time: {latency * 1000000:.3f} us")
 
 
-if __name__ == "__main__":
+def main():
+    """Main function for Buck compatibility."""
     parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
     parser.add_argument("--num-tokens", type=int, default=4096)
     parser.add_argument("--hidden-size", type=int, default=8192)
@@ -81,7 +82,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     args = parser.parse_args()
     print(args)
 
-    main(
+    run_benchmark(
         num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         add_residual=args.add_residual,
@@ -91,3 +92,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters,
     )
+
+
+if __name__ == "__main__":
+    main()
```

csrc/layernorm_kernels.cu

Lines changed: 275 additions & 24 deletions
```diff
@@ -3,6 +3,7 @@
 
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
+#include "quantization/vectorization_utils.cuh"
 
 #ifndef USE_ROCM
 #include <cub/cub.cuh>
```

```diff
@@ -12,38 +13,263 @@
 
 namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
+// ---------- helpers ----------
+template <typename T>
+__device__ __forceinline__ T warp_sum(T v) {
+#ifdef __HIP_PLATFORM_AMD__
+  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
+#else
+  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
+#endif
+  // Always reduce over 32 lanes to match downstream logic.
+  constexpr int kWidth = 32;
+  v += __shfl_down_sync(m, v, 16, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
+  return v;
+}
+
+template <typename T>
+__device__ __forceinline__ bool same_phase(const T* a, const T* b, int widthB) {
+  auto ai = reinterpret_cast<uintptr_t>(a);
+  auto bi = reinterpret_cast<uintptr_t>(b);
+  return ((ai ^ bi) & (widthB - 1)) == 0;
+}
+
+// Safe 16B copy to shared: prefix to align, vector main, scalar tail.
+template <typename T>
+__device__ __forceinline__ void copy_row_to_shared_aligned(
+    const T* __restrict__ src, T* __restrict__ dst, int n_elems, int tid) {
+  const uintptr_t sa = reinterpret_cast<uintptr_t>(src);
+  const uintptr_t da = reinterpret_cast<uintptr_t>(dst);
+  const bool same16 = (((sa ^ da) & (16 - 1)) == 0);
+
+  if (!same16) {
+    for (int i = tid; i < n_elems; i += blockDim.x) dst[i] = src[i];
+    __syncthreads();
+    return;
+  }
+  const int ebytes = sizeof(T);
+  const int perVec = 16 / ebytes;
+
+  int prefix = 0;
+  const int mis = sa & (16 - 1);
+  if (mis) prefix = (16 - mis) / ebytes;
+  if (prefix > n_elems) prefix = n_elems;
+
+  // scalar prefix
+  for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];
+
+  // vector main
+  const int remain = n_elems - prefix;
+  const int main_elems = (remain / perVec) * perVec;
+  if (main_elems > 0) {
+    const uint4* __restrict__ vsrc =
+        reinterpret_cast<const uint4*>(src + prefix);
+
+#if defined(__HIP_PLATFORM_AMD__)
+    // ROCm: vector load from global, scalar 32-bit stores to shared
+    uint32_t* __restrict__ s32 =
+        reinterpret_cast<uint32_t*>(dst + prefix);
+    const int nvec = main_elems / perVec;  // 16B packets
+    constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      uint4 p = vsrc[v];
+      const int base = v * WORDS_PER_PKT;
+      s32[base + 0] = p.x;
+      s32[base + 1] = p.y;
+      s32[base + 2] = p.z;
+      s32[base + 3] = p.w;
+    }
+#else
+    // CUDA: vector load + vector store (fastest)
+    uint4* __restrict__ vdst =
+        reinterpret_cast<uint4*>(dst + prefix);
+    const int nvec = main_elems / perVec;
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      uint4 p = vsrc[v];
+      vdst[v] = p;
+    }
+#endif
+  }
+
+  // scalar tail
+  const int tail = prefix + main_elems;
+  for (int i = tid + tail; i < n_elems; i += blockDim.x) dst[i] = src[i];
+  __syncthreads();
+}
+
+// ---------------- vec/scalar ops (generic, used for all dtypes) ----------------
+template <int V, typename T>
+struct VecMulNormWeight {
+  const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with in/out)
+  float inv_rms;
+  int stride_vec;
+  mutable int64_t vec_idx;
+
+  __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                             const vec_n_t<T, V>& src) const {
+    const vec_n_t<T, V> w = wv[vec_idx];
+#pragma unroll
+    for (int j = 0; j < V; ++j) {
+      T xn = static_cast<T>(static_cast<float>(src.val[j]) * inv_rms);
+      dst.val[j] = xn * w.val[j];
+    }
+    vec_idx += stride_vec;
+  }
+};
+
+template <typename T>
+struct ScalarMulNormWeight {
+  const T* __restrict__ w_base;  // already offset by +prefix
+  T* __restrict__ out_base;      // out_row + prefix
+  float inv_rms;
+  __device__ __forceinline__ void operator()(T& dst, const T src) const {
+    const int i = static_cast<int>(&dst - out_base);
+    T xn = static_cast<T>(static_cast<float>(src) * inv_rms);
+    dst = xn * w_base[i];
+  }
+};
+
+// ---------- kernel ----------
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
     const int64_t input_stride,
-    const scalar_t* __restrict__ weight,  // [hidden_size]
-    const float epsilon, const int num_tokens, const int hidden_size) {
-  __shared__ float s_variance;
-  float variance = 0.0f;
+    const scalar_t* __restrict__ weight,  // [hidden_size]
+    const float epsilon, const int /*num_tokens*/, const int hidden_size,
+    int smem_elems) {
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
-    variance += x * x;
+  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
+
+  // Optional cached-row (half) when host provisioned shmem
+  extern __shared__ unsigned char smem_raw[];
+  scalar_t* s_in = reinterpret_cast<scalar_t*>(smem_raw);
+
+#ifdef __HIP_PLATFORM_AMD__
+  constexpr bool kAllowCache = false;
+#else
+  constexpr bool kAllowCache = true;
+#endif
+  const bool use_cached =
+      kAllowCache && (sizeof(scalar_t) == 2) && (smem_elems > 0);
+
+#if !defined(__HIP_PLATFORM_AMD__)
+  if (use_cached) {
+    copy_row_to_shared_aligned(in_row, s_in, hidden_size, threadIdx.x);
   }
+#endif
 
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  // -------- Pass 1: sum of squares --------
+  using acc_t = float;
+  acc_t sumsq = acc_t(0);
+  if (use_cached) {
+    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      acc_t x = static_cast<acc_t>(s_in[i]);
+      sumsq += x * x;
+    }
+  } else {
+    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      acc_t x = static_cast<acc_t>(in_row[i]);
+      sumsq += x * x;
+    }
+  }
 
-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  // warp + block reduction in acc_t
+  acc_t wsum = warp_sum<acc_t>(sumsq);
+  __shared__ acc_t warp_sums_sh[32];
+  if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
+  __syncthreads();
+
+  acc_t total = acc_t(0);
+  if (threadIdx.x < 32) {
+    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x] : acc_t(0);
+    total = warp_sum<acc_t>(v);
+    if (threadIdx.x == 0) warp_sums_sh[0] = total;
   }
   __syncthreads();
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+  // compute inv_rms in float to match baseline epsilon semantics
+  const float inv_rms =
+      rsqrtf(static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) + epsilon);
+
+  // -------- Fast path: HS == blockDim.x (e.g., 1024) --------
+  if (hidden_size == blockDim.x) {
+    int i = threadIdx.x;
+    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
+    return;
+  }
+
+  // -------- Pass 2: Vectorize when phases align --------
+  constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B packets
+  constexpr int WIDTH = V * sizeof(scalar_t);
+
+  const bool can_vec =
+      (hidden_size % V == 0) &&
+      same_phase(in_row, out_row, WIDTH) &&
+      same_phase(in_row, weight, WIDTH);
+
+  if (can_vec) {
+    const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
+    const int mis = addr & (WIDTH - 1);
+    const int prefix = mis ? (WIDTH - mis) / (int)sizeof(scalar_t) : 0;
+
+    // scalar prefix
+    for (int i = threadIdx.x; i < prefix; i += blockDim.x) {
+      float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+      out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+    }
+
+    // vector main
+    const int remain = hidden_size - prefix;
+    const int main_len = (remain / V) * V;
+    if (main_len > 0) {
+      using VecT = vec_n_t<scalar_t, V>;
+      const VecT* __restrict__ wv =
+          reinterpret_cast<const VecT*>(weight + prefix);
+
+      VecMulNormWeight<V, scalar_t> vec_op{
+          /*wv=*/ wv,
+          /*inv_rms=*/ inv_rms,
+          /*stride_vec=*/ (int)blockDim.x,
+          /*vec_idx=*/ (int64_t)threadIdx.x
+      };
+      ScalarMulNormWeight<scalar_t> sca_op{
+          /*w_base=*/ weight + prefix,
+          /*out_base=*/ out_row + prefix,
+          /*inv_rms=*/ inv_rms
+      };
+
+      const scalar_t* vin = use_cached ? (s_in + prefix) : (in_row + prefix);
+      vectorize_with_alignment<V>(
+          vin, out_row + prefix, main_len,
+          threadIdx.x, blockDim.x, vec_op, sca_op);
+    }
+
+    // scalar tail
+    const int tail = prefix + ((remain / V) * V);
+    for (int i = threadIdx.x + tail; i < hidden_size; i += blockDim.x) {
+      float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+      out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+    }
+    return;
+  }
+
+  // -------- Fallback scalar --------
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
   }
 }
 
+
 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
```

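The vector path above only engages when the input row, output row, and weight share the same 16-byte "phase": their addresses are congruent modulo the packet width, so one scalar prefix length aligns all three at the same element index. Below is a small host-side C++ sketch of that arithmetic, using the same XOR/mask trick as `same_phase` and the prefix computation in the hunk; the buffer offsets are made-up values chosen to show one aligned pair and one misaligned one.

```cpp
#include <cstdint>
#include <cstdio>

// Two pointers share a "phase" w.r.t. a power-of-two width when their
// addresses are congruent modulo that width; a single scalar prefix then
// brings both to a vector boundary at the same element index.
static bool same_phase(const void* a, const void* b, int width_bytes) {
  auto ai = reinterpret_cast<uintptr_t>(a);
  auto bi = reinterpret_cast<uintptr_t>(b);
  return ((ai ^ bi) & (width_bytes - 1)) == 0;
}

int main() {
  alignas(16) unsigned char buf[64];
  const int WIDTH = 16;              // 8 x fp16 = one 16-byte packet
  const void* in_row = buf + 4;      // 4 bytes past a 16B boundary
  const void* out_row = buf + 36;    // 36 % 16 == 4 -> same phase as in_row
  const void* weight = buf + 8;      // 8 % 16 == 8 -> different phase

  std::printf("in/out same phase: %d\n", same_phase(in_row, out_row, WIDTH));  // 1
  std::printf("in/w   same phase: %d\n", same_phase(in_row, weight, WIDTH));   // 0

  // Scalar-prefix length (in fp16 elements) needed to reach the next 16B boundary.
  const int mis = static_cast<int>(reinterpret_cast<uintptr_t>(in_row) & (WIDTH - 1));  // 4
  const int prefix = mis ? (WIDTH - mis) / 2 : 0;  // (16 - 4) / sizeof(fp16) = 6
  std::printf("prefix elems: %d\n", prefix);
  return 0;
}
```

When any of the three pointers falls out of phase, the kernel simply drops to the scalar fallback loop rather than risk misaligned vector accesses.
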
```diff
@@ -142,6 +368,13 @@ fused_add_rms_norm_kernel(
 
 }  // namespace vllm
 
+static inline int ln_block_threads_unified(int H) {
+  int threads = (H >= 1024) ? 256
+                : (H >= 512) ? 512
+                : std::min(1024, ((H + 31) / 32) * 32);
+  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
+}
+
 void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& input,  // [..., hidden_size]
               torch::Tensor& weight, // [hidden_size]
```

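As a sanity check on the shared launch heuristic, the host-only sketch below reproduces `ln_block_threads_unified` and prints the block size it picks for a few hidden sizes (the sample sizes are arbitrary; the function body is copied from the hunk above):

```cpp
#include <algorithm>
#include <cstdio>
#include <initializer_list>

static inline int ln_block_threads_unified(int H) {
  int threads = (H >= 1024) ? 256
                : (H >= 512) ? 512
                : std::min(1024, ((H + 31) / 32) * 32);
  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
}

int main() {
  // H >= 1024 -> 256 threads, 512 <= H < 1024 -> 512, otherwise H rounded up
  // to a whole warp and clamped to [128, 1024].
  for (int H : {64, 300, 512, 768, 1024, 4096}) {
    std::printf("H=%4d -> %d threads\n", H, ln_block_threads_unified(H));
  }
  return 0;
}
```

This prints 128, 320, 512, 512, 256, and 256 respectively; wide rows deliberately get fewer threads per block, matching the launch settings of the unfused path as described in the summary.
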
```diff
@@ -150,21 +383,39 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  int64_t input_stride = input.stride(-2);
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int64_t in_stride = input.stride(-2);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  dim3 block(ln_block_threads_unified(hidden_size));
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Optional cached-row for FP16 (recommended). Kernel still works if this is 0.
+  size_t shmem_bytes = 0;
+  int smem_elems = 0;
+  if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
+    shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
+    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
+  }
+
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
-        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
+    vllm::rms_norm_kernel<scalar_t>
+        <<<grid, block, shmem_bytes, stream>>>(
+            out.data_ptr<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            in_stride,
+            weight.data_ptr<scalar_t>(),
+            static_cast<float>(epsilon),
+            num_tokens,
+            hidden_size,
+            smem_elems);
   });
 }
 
+
 #define LAUNCH_FUSED_ADD_RMS_NORM(width)                              \
   VLLM_DISPATCH_FLOATING_TYPES(                                       \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] {         \
```

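The launcher only provisions dynamic shared memory for fp16 inputs with `hidden_size <= 4096`, which keeps the per-block request small; larger or non-fp16 rows take the uncached path with `smem_elems == 0`. A quick back-of-envelope check (the ~48 KiB figure is the usual default dynamic shared-memory limit per block, an assumption not stated in the patch):

```cpp
#include <cstdio>

int main() {
  // Largest row the launcher will cache: 4096 fp16 elements.
  const int max_hidden = 4096;
  const size_t shmem_bytes = static_cast<size_t>(max_hidden) * 2;  // sizeof(half) == 2 bytes
  // 8192 bytes per block, far below the ~48 KiB default dynamic shared-memory
  // limit, so no cudaFuncSetAttribute opt-in is needed for this launch.
  std::printf("max cached row: %zu bytes\n", shmem_bytes);
  return 0;
}
```
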
