Vectorize RMSNorm CUDA kernel #22602
Open
bbeckca wants to merge 2 commits into vllm-project:main from bbeckca:export-D79969610
+351 −62
Changes from all commits (2 commits)
@@ -3,6 +3,7 @@
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
+#include "quantization/vectorization_utils.cuh"

 #ifndef USE_ROCM
 #include <cub/cub.cuh>
@@ -12,35 +13,254 @@
 namespace vllm {

-// TODO(woosuk): Further optimize this kernel.
+// ---------- helpers ----------
+template <typename T>
+__device__ __forceinline__ T warp_sum(T v) {
+#ifdef __HIP_PLATFORM_AMD__
+  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
+#else
+  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
+#endif
+  // Always reduce over 32 lanes to match downstream logic.
+  constexpr int kWidth = 32;
+  v += __shfl_down_sync(m, v, 16, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
+  return v;
+}

+template <typename T>
+__device__ __forceinline__ bool same_phase(const T* a, const T* b, int widthB) {
+  auto ai = reinterpret_cast<uintptr_t>(a);
+  auto bi = reinterpret_cast<uintptr_t>(b);
+  return ((ai ^ bi) & (widthB - 1)) == 0;
+}

+// Safe 16B copy to shared: prefix to align, vector main, scalar tail.
+template <typename T>
+__device__ __forceinline__ void copy_row_to_shared_aligned(
+    const T* __restrict__ src, T* __restrict__ dst, int n_elems, int tid) {
+  const uintptr_t sa = reinterpret_cast<uintptr_t>(src);
+  const uintptr_t da = reinterpret_cast<uintptr_t>(dst);
+  const bool same16 = (((sa ^ da) & (16 - 1)) == 0);

+  if (!same16) {
+    for (int i = tid; i < n_elems; i += blockDim.x) dst[i] = src[i];
+    __syncthreads();
+    return;
+  }
+  const int ebytes = sizeof(T);
+  const int perVec = 16 / ebytes;

+  int prefix = 0;
+  const int mis = sa & (16 - 1);
+  if (mis) prefix = (16 - mis) / ebytes;
+  if (prefix > n_elems) prefix = n_elems;

+  // scalar prefix
+  for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];

+  // vector main
+  const int remain = n_elems - prefix;
+  const int main_elems = (remain / perVec) * perVec;
+  if (main_elems > 0) {
+    const uint4* __restrict__ vsrc =
+        reinterpret_cast<const uint4*>(src + prefix);

+#if defined(__HIP_PLATFORM_AMD__)
+    // ROCm: vector load from global, scalar 32-bit stores to shared
+    uint32_t* __restrict__ s32 = reinterpret_cast<uint32_t*>(dst + prefix);
+    const int nvec = main_elems / perVec;  // 16B packets
+    constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      uint4 p = vsrc[v];
+      const int base = v * WORDS_PER_PKT;
+      s32[base + 0] = p.x;
+      s32[base + 1] = p.y;
+      s32[base + 2] = p.z;
+      s32[base + 3] = p.w;
+    }
+#else
+    // CUDA: vector load + vector store (fastest)
+    uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst + prefix);
+    const int nvec = main_elems / perVec;
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      uint4 p = vsrc[v];
+      vdst[v] = p;
+    }
+#endif
+  }

+  // scalar tail
+  const int tail = prefix + main_elems;
+  for (int i = tid + tail; i < n_elems; i += blockDim.x) dst[i] = src[i];
+  __syncthreads();
+}
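As a quick sanity check of the prefix/main/tail split that copy_row_to_shared_aligned performs, here is a small host-side walk-through; the 4096-element fp16 row and the 8-byte source misalignment are assumed example values, not taken from the PR:

```cpp
#include <cstdio>

// Illustration only: mirrors the prefix/vector/tail arithmetic of
// copy_row_to_shared_aligned for fp16 (2-byte elements, 16-byte packets).
int main() {
  const int ebytes = 2;            // sizeof(half), assumed dtype
  const int perVec = 16 / ebytes;  // 8 elements per 16B packet
  const int n_elems = 4096;        // assumed hidden size
  const int mis = 8;               // assumed misalignment of src, in bytes

  const int prefix = mis ? (16 - mis) / ebytes : 0;   // 4 scalar elements
  const int remain = n_elems - prefix;                // 4092
  const int main_elems = (remain / perVec) * perVec;  // 4088 (511 packets)
  const int tail = n_elems - prefix - main_elems;     // 4 scalar elements

  printf("prefix=%d main=%d tail=%d\n", prefix, main_elems, tail);
  return 0;
}
```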

+// ---------------- vec/scalar ops (generic, used for all dtypes)
+// ----------------
+template <int V, typename T>
+struct VecMulNormWeight {
+  const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with
+                                         // in/out)
+  float inv_rms;
+  int stride_vec;
+  mutable int64_t vec_idx;

+  __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                             const vec_n_t<T, V>& src) const {
+    const vec_n_t<T, V> w = wv[vec_idx];
+#pragma unroll
+    for (int j = 0; j < V; ++j) {
+      T xn = static_cast<T>(static_cast<float>(src.val[j]) * inv_rms);
+      dst.val[j] = xn * w.val[j];
+    }
+    vec_idx += stride_vec;
+  }
+};

+template <typename T>
+struct ScalarMulNormWeight {
+  const T* __restrict__ w_base;  // already offset by +prefix
+  T* __restrict__ out_base;      // out_row + prefix
+  float inv_rms;
+  __device__ __forceinline__ void operator()(T& dst, const T src) const {
+    const int i = static_cast<int>(&dst - out_base);
+    T xn = static_cast<T>(static_cast<float>(src) * inv_rms);
+    dst = xn * w_base[i];
+  }
+};

+// ---------- kernel ----------
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
     const int64_t input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
-    const float epsilon, const int num_tokens, const int hidden_size) {
-  __shared__ float s_variance;
-  float variance = 0.0f;
+    const float epsilon, const int /*num_tokens*/, const int hidden_size,
+    int smem_elems) {
+  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;

-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
-    variance += x * x;
+  // Optional cached-row (half) when host provisioned shmem
+  extern __shared__ unsigned char smem_raw[];
+  scalar_t* s_in = reinterpret_cast<scalar_t*>(smem_raw);

+#ifdef __HIP_PLATFORM_AMD__
+  constexpr bool kAllowCache = false;
+#else
+  constexpr bool kAllowCache = true;
+#endif
+  const bool use_cached =
+      kAllowCache && (sizeof(scalar_t) == 2) && (smem_elems > 0);

+#if !defined(__HIP_PLATFORM_AMD__)
+  if (use_cached) {
+    copy_row_to_shared_aligned(in_row, s_in, hidden_size, threadIdx.x);
+  }
+#endif

-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  // -------- Pass 1: sum of squares --------
+  using acc_t = float;
+  acc_t sumsq = acc_t(0);
+  if (use_cached) {
+    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      acc_t x = static_cast<acc_t>(s_in[i]);
+      sumsq += x * x;
+    }
+  } else {
+    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      acc_t x = static_cast<acc_t>(in_row[i]);
+      sumsq += x * x;
+    }
+  }

-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  // warp + block reduction in acc_t
+  acc_t wsum = warp_sum<acc_t>(sumsq);
+  __shared__ acc_t warp_sums_sh[32];
+  if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
+  __syncthreads();

+  acc_t total = acc_t(0);
+  if (threadIdx.x < 32) {
+    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x]
+                                                     : acc_t(0);
+    total = warp_sum<acc_t>(v);
+    if (threadIdx.x == 0) warp_sums_sh[0] = total;
+  }
+  __syncthreads();

-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+  // compute inv_rms in float to match baseline epsilon semantics
+  const float inv_rms = rsqrtf(
+      static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) +
+      epsilon);

+  // -------- Fast path: HS == blockDim.x (e.g., 1024) --------
+  if (hidden_size == blockDim.x) {
+    int i = threadIdx.x;
+    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
+    return;
+  }

+  // -------- Pass 2: Vectorize when phases align --------
+  constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B packets
+  constexpr int WIDTH = V * sizeof(scalar_t);

+  const bool can_vec = (hidden_size % V == 0) &&
+                       same_phase(in_row, out_row, WIDTH) &&
+                       same_phase(in_row, weight, WIDTH);

+  if (can_vec) {
+    const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
+    const int mis = addr & (WIDTH - 1);
+    const int prefix = mis ? (WIDTH - mis) / (int)sizeof(scalar_t) : 0;

+    // scalar prefix
+    for (int i = threadIdx.x; i < prefix; i += blockDim.x) {
+      float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+      out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+    }
Comment on lines +210 to +227: We have the duplicate logic in …
+    // vector main
+    const int remain = hidden_size - prefix;
+    const int main_len = (remain / V) * V;
+    if (main_len > 0) {
+      using VecT = vec_n_t<scalar_t, V>;
+      const VecT* __restrict__ wv =
+          reinterpret_cast<const VecT*>(weight + prefix);

+      VecMulNormWeight<V, scalar_t> vec_op{/*wv=*/wv,
+                                           /*inv_rms=*/inv_rms,
+                                           /*stride_vec=*/(int)blockDim.x,
+                                           /*vec_idx=*/(int64_t)threadIdx.x};
+      ScalarMulNormWeight<scalar_t> sca_op{/*w_base=*/weight + prefix,
+                                           /*out_base=*/out_row + prefix,
+                                           /*inv_rms=*/inv_rms};

+      const scalar_t* vin = use_cached ? (s_in + prefix) : (in_row + prefix);
+      vectorize_with_alignment<V>(vin, out_row + prefix, main_len, threadIdx.x,
+                                  blockDim.x, vec_op, sca_op);
+    }

+    // scalar tail
+    const int tail = prefix + ((remain / V) * V);
+    for (int i = threadIdx.x + tail; i < hidden_size; i += blockDim.x) {
+      float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+      out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+    }
+    return;
+  }

+  // -------- Fallback scalar --------
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
Comment on lines +259 to +263: Same above.
   }
 }
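For intuition on the can_vec gate above: the two same_phase checks require that input, output, and weight rows agree on their address modulo the 16-byte packet width, so a single scalar prefix aligns all three at once. A standalone sketch of the check, with made-up example addresses (not from the PR):

```cpp
#include <cstdint>
#include <cstdio>

// Two pointers are "in phase" for widthB-byte vectors when their addresses
// agree modulo widthB; one shared scalar prefix then aligns both.
static bool same_phase(uintptr_t a, uintptr_t b, int widthB) {
  return ((a ^ b) & (widthB - 1)) == 0;
}

int main() {
  const int WIDTH = 16;                  // 8 x fp16 per packet
  const uintptr_t in  = 0x7f0000001008;  // assumed addresses, illustration only
  const uintptr_t out = 0x7f0000002008;
  const uintptr_t w   = 0x7f0000003000;

  printf("in/out in phase: %d\n", same_phase(in, out, WIDTH));  // 1: both 8 mod 16
  printf("in/w   in phase: %d\n", same_phase(in, w, WIDTH));    // 0: 8 vs 0 mod 16
  return 0;
}
```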
@@ -142,6 +362,13 @@ fused_add_rms_norm_kernel(
 }  // namespace vllm

+static inline int ln_block_threads_unified(int H) {
+  int threads = (H >= 1024) ? 256
+                : (H >= 512) ? 512
+                             : std::min(1024, ((H + 31) / 32) * 32);
+  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
+}

 void rms_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
@@ -150,18 +377,30 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());

-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  int64_t input_stride = input.stride(-2);
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int64_t in_stride = input.stride(-2);

   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  dim3 block(ln_block_threads_unified(hidden_size));

   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

+  // Optional cached-row for FP16 (recommended). Kernel still works if this is
+  // 0.
+  size_t shmem_bytes = 0;
+  int smem_elems = 0;
+  if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
+    shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
+    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
+  }

   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
-        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shmem_bytes, stream>>>(
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), in_stride,
+        weight.data_ptr<scalar_t>(), static_cast<float>(epsilon), num_tokens,
+        hidden_size, smem_elems);
   });
 }
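For reference, here is what the ln_block_threads_unified heuristic above picks for a few hidden sizes; the function body is copied from the diff, and the H values are chosen only for illustration:

```cpp
#include <algorithm>
#include <cstdio>
#include <initializer_list>

// Copy of the block-size heuristic from the diff, for a quick H -> threads table.
static inline int ln_block_threads_unified(int H) {
  int threads = (H >= 1024) ? 256
                : (H >= 512) ? 512
                             : std::min(1024, ((H + 31) / 32) * 32);
  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
}

int main() {
  for (int H : {64, 300, 512, 768, 1024, 4096}) {
    printf("H=%4d -> %d threads\n", H, ln_block_threads_unified(H));
  }
  // Expected: 64 -> 128, 300 -> 320, 512 -> 512, 768 -> 512,
  //           1024 -> 256, 4096 -> 256
  return 0;
}
```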
Note: warp_sum and ln_block_threads_unified are the same for rms_norm_kernel and rms_norm_static_fp8_quant. Will move to a shared helper after aligning on the high-level approach.
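One possible shape for that shared helper, sketched under the assumption that both kernels live in the same translation units; the header name is hypothetical (not from the PR) and the bodies are copied from this diff:

```cpp
// norm_reduce_utils.cuh -- hypothetical shared header name, not from the PR
#pragma once

#include <algorithm>

namespace vllm {

// Warp-level sum over 32 lanes, shared by rms_norm_kernel and
// rms_norm_static_fp8_quant (body copied from this diff).
template <typename T>
__device__ __forceinline__ T warp_sum(T v) {
#ifdef __HIP_PLATFORM_AMD__
  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
#else
  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
#endif
  constexpr int kWidth = 32;
  v += __shfl_down_sync(m, v, 16, kWidth);
  v += __shfl_down_sync(m, v, 8, kWidth);
  v += __shfl_down_sync(m, v, 4, kWidth);
  v += __shfl_down_sync(m, v, 2, kWidth);
  v += __shfl_down_sync(m, v, 1, kWidth);
  return v;
}

}  // namespace vllm

// Host-side block-size heuristic shared by both launchers
// (body copied from this diff).
static inline int ln_block_threads_unified(int H) {
  int threads = (H >= 1024) ? 256
                : (H >= 512) ? 512
                             : std::min(1024, ((H + 31) / 32) * 32);
  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
}
```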