11 changes: 8 additions & 3 deletions benchmarks/kernels/benchmark_layernorm.py
Contributor Author:
Note: These changes were made for testing within Meta infra and will be deleted before landing.

@@ -11,7 +11,7 @@


@torch.inference_mode()
def main(
def run_benchmark(
num_tokens: int,
hidden_size: int,
add_residual: bool,
@@ -59,7 +59,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
def main():
"""Main function for Buck compatibility."""
parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
@@ -81,7 +82,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
args = parser.parse_args()
print(args)

main(
run_benchmark(
num_tokens=args.num_tokens,
hidden_size=args.hidden_size,
add_residual=args.add_residual,
@@ -91,3 +92,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters,
)


if __name__ == "__main__":
main()
275 changes: 247 additions & 28 deletions csrc/layernorm_kernels.cu
@@ -3,6 +3,7 @@

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "quantization/vectorization_utils.cuh"

#ifndef USE_ROCM
#include <cub/cub.cuh>
@@ -12,35 +13,232 @@

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
constexpr int kVecBytes = 16; // 128-bit phase

template <typename T>
__device__ __forceinline__ T warp_sum(T v) {
#ifdef __HIP_PLATFORM_AMD__
const unsigned long long m = 0xffffffffffffffffull;
#else
const unsigned m = 0xffffffffu;
#endif
constexpr int kWidth = 32;
v += __shfl_down_sync(m, v, 16, kWidth);
v += __shfl_down_sync(m, v, 8, kWidth);
v += __shfl_down_sync(m, v, 4, kWidth);
v += __shfl_down_sync(m, v, 2, kWidth);
v += __shfl_down_sync(m, v, 1, kWidth);
return v;
}
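// Reduction note: warp_sum is a five-step shuffle tree over a 32-lane window
// (offsets 16, 8, 4, 2, 1); after the last step lane 0 holds the sum of all
// 32 lane values. Only lane 0's result is meaningful. The host-side block
// size heuristic below rounds blockDim.x to a multiple of 32, so every warp
// participating here is full.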

template <typename T>
__device__ __forceinline__ bool same_phase(const T* a, const T* b, int bytes) {
const auto ai = reinterpret_cast<uintptr_t>(a);
const auto bi = reinterpret_cast<uintptr_t>(b);
return ((ai ^ bi) & (bytes - 1)) == 0;
}
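// Phase check: for power-of-two `bytes`, ((ai ^ bi) & (bytes - 1)) == 0
// holds exactly when a and b have the same offset modulo `bytes`. E.g.
// a % 16 == 8 and b % 16 == 8 are in phase: after a common scalar prefix of
// the same length, both pointers reach 16B alignment together and can be
// walked with identical vector packets.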

// Copy an input row to shared memory, using 16B vector copies when the
// source and destination share the same 16B phase.
template <typename T>
__device__ __forceinline__ void copy_row_to_shared_aligned(
const T* __restrict__ src, T* __restrict__ dst, int n_elems, int tid) {
const auto sa = reinterpret_cast<uintptr_t>(src);
const auto da = reinterpret_cast<uintptr_t>(dst);
const bool same = (((sa ^ da) & (kVecBytes - 1)) == 0);

if (!same) {
for (int i = tid; i < n_elems; i += blockDim.x) dst[i] = src[i];
__syncthreads();
return;
}

const int ebytes = sizeof(T);
const int perVec = kVecBytes / ebytes;

int prefix = 0;
const int mis = sa & (kVecBytes - 1);
if (mis) prefix = (kVecBytes - mis) / ebytes;
if (prefix > n_elems) prefix = n_elems;

for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];

const int remain = n_elems - prefix;
const int main_elems = (remain / perVec) * perVec;
if (main_elems > 0) {
const uint4* __restrict__ vsrc =
reinterpret_cast<const uint4*>(src + prefix);
#if defined(__HIP_PLATFORM_AMD__)
uint32_t* __restrict__ s32 = reinterpret_cast<uint32_t*>(dst + prefix);
const int nvec = main_elems / perVec;
constexpr int WORDS_PER_PKT = kVecBytes / sizeof(uint32_t); // 4
for (int v = tid; v < nvec; v += blockDim.x) {
const uint4 p = vsrc[v];
const int base = v * WORDS_PER_PKT;
s32[base + 0] = p.x;
s32[base + 1] = p.y;
s32[base + 2] = p.z;
s32[base + 3] = p.w;
}
#else
uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst + prefix);
const int nvec = main_elems / perVec;
for (int v = tid; v < nvec; v += blockDim.x) {
vdst[v] = vsrc[v];
}
#endif
}

const int tail = prefix + main_elems;
for (int i = tid + tail; i < n_elems; i += blockDim.x) dst[i] = src[i];
__syncthreads();
}
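// Worked example (assumed shapes): fp16 row, n_elems = 4096, src % 16 == 8.
// prefix = (16 - 8) / 2 = 4 scalar copies to reach 16B alignment; the
// remaining 4092 halves yield main_elems = 511 * 8 = 4088 moved as 511
// uint4 packets, and the last 4 elements fall through to the scalar tail.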

// functors for vectorized write
template <int V, typename T>
struct VecMulNormWeight {
const vec_n_t<T, V>* __restrict__ wv;
float inv_rms;
int stride_vec;
mutable int64_t vec_idx;
__device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
const vec_n_t<T, V>& src) const {
const vec_n_t<T, V> w = wv[vec_idx];
#pragma unroll
for (int j = 0; j < V; ++j) {
const T xn = static_cast<T>(static_cast<float>(src.val[j]) * inv_rms);
dst.val[j] = xn * w.val[j];
}
vec_idx += stride_vec;
}
};

template <typename T>
struct ScalarMulNormWeight {
const T* __restrict__ w_base;
T* __restrict__ out_base;
float inv_rms;
__device__ __forceinline__ void operator()(T& dst, const T src) const {
const int i = static_cast<int>(&dst - out_base);
const T xn = static_cast<T>(static_cast<float>(src) * inv_rms);
dst = xn * w_base[i];
}
};
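// Index recovery note: the scalar functor receives no element index, so it
// rederives i via pointer arithmetic (&dst - out_base). This assumes dst
// always aliases an element of out_base's row, which holds because the
// kernel passes out_row both as the write destination and as out_base.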

template <int V, typename T>
struct VecNormMulWeightScalarW {
const T* __restrict__ w_base; // offset by prefix
float inv_rms;
int stride_vec;
mutable int vec_idx;
__device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
const vec_n_t<T, V>& src) const {
const int base = vec_idx * V;
#pragma unroll
for (int j = 0; j < V; ++j) {
const float x = static_cast<float>(src.val[j]) * inv_rms;
dst.val[j] = static_cast<T>(x * static_cast<float>(w_base[base + j]));
}
vec_idx += stride_vec;
}
};
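// These functors are consumed by vectorize_with_alignment from
// quantization/vectorization_utils.cuh, which splits a row into a scalar
// prefix, aligned vec_n_t<T, V> packets, and a scalar tail, applying the
// vector functor to packets and the scalar functor to the leftovers. The
// self-maintained vec_idx cursor (start threadIdx.x, stride blockDim.x)
// assumes the helper visits packets in the same tid-strided order.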

template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
__global__ void rms_norm_kernel(scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const int64_t input_stride,
const scalar_t* __restrict__ weight,
const float epsilon, const int /*num_tokens*/,
const int hidden_size, int smem_elems) {
const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * input_stride + idx];
variance += x * x;
extern __shared__ unsigned char smem_raw[];
scalar_t* s_in = reinterpret_cast<scalar_t*>(smem_raw);

#ifdef __HIP_PLATFORM_AMD__
constexpr bool kAllowCache = false;
#else
constexpr bool kAllowCache = true;
#endif
const bool use_cached =
kAllowCache && (sizeof(scalar_t) == 2) && (smem_elems > 0);

#if !defined(__HIP_PLATFORM_AMD__)
if (use_cached)
copy_row_to_shared_aligned(in_row, s_in, hidden_size, threadIdx.x);
#endif
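// Two-pass plan: the loop below accumulates the sum of squares (reading the
// smem copy when cached, global memory otherwise); the write phase at the
// end rereads the same source, so the cache saves one full global read of
// the row per block.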

float sumsq = 0.f;
{
const scalar_t* base = use_cached ? s_in : in_row;
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
const float x = static_cast<float>(base[i]);
sumsq += x * x;
}
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
float wsum = warp_sum<float>(sumsq);
__shared__ float warp_sums_sh[32];
if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
__syncthreads();

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
if (threadIdx.x < 32) {
const int nwarps = (blockDim.x + 31) / 32;
const float v = (threadIdx.x < nwarps) ? warp_sums_sh[threadIdx.x] : 0.f;
const float total = warp_sum<float>(v);
if (threadIdx.x == 0) warp_sums_sh[0] = total;
}
__syncthreads();
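// After this barrier warp_sums_sh[0] holds the block-wide sum of squares:
// each warp stored its partial above, and the first warp folded up to 32
// partials (blockDim.x is capped at 1024) with a second warp_sum pass.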

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * input_stride + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
const float inv_rms =
rsqrtf(warp_sums_sh[0] / static_cast<float>(hidden_size) + epsilon);

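// Fast path: when blockDim.x exactly equals hidden_size, each thread owns
// one element and the strided loops below are unnecessary.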
if (hidden_size == blockDim.x) {
const int i = threadIdx.x;
const float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
const scalar_t xn = static_cast<scalar_t>(x * inv_rms);
out_row[i] = xn * weight[i];
return;
}

constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4; // 16B
constexpr int WIDTH = V * sizeof(scalar_t);
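// Vector stores are safe only if the row length is a multiple of the packet
// width and the input and output rows share the same 16B phase, so the
// scalar prefix/tail derived from the input pointer is also valid for the
// output pointer.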
const bool vec_store_ok =
(hidden_size % V == 0) && same_phase(in_row, out_row, WIDTH);

const bool s_same = use_cached && same_phase(in_row, s_in, kVecBytes);
const scalar_t* vin = s_same ? s_in : in_row;

if (vec_store_ok) {
ScalarMulNormWeight<scalar_t> sca_op{weight, out_row, inv_rms};

const auto addr = reinterpret_cast<uintptr_t>(vin);
const int mis = addr & (WIDTH - 1);
const int prefix =
mis ? (WIDTH - mis) / static_cast<int>(sizeof(scalar_t)) : 0;

if (same_phase(in_row, weight, WIDTH)) {
using VecT = vec_n_t<scalar_t, V>;
const VecT* __restrict__ wv =
reinterpret_cast<const VecT*>(weight + prefix);
VecMulNormWeight<V, scalar_t> vec_op{wv, inv_rms, (int)blockDim.x,
(int64_t)threadIdx.x};
vectorize_with_alignment<V>(vin, out_row, hidden_size, threadIdx.x,
blockDim.x, vec_op, sca_op);
} else {
VecNormMulWeightScalarW<V, scalar_t> vec_op{
weight + prefix, inv_rms, (int)blockDim.x, (int)threadIdx.x};
vectorize_with_alignment<V>(vin, out_row, hidden_size, threadIdx.x,
blockDim.x, vec_op, sca_op);
}
return;
}

// scalar fallback (keeps op order identical to fused path)
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
const float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
const scalar_t xn = static_cast<scalar_t>(x * inv_rms);
out_row[i] = xn * weight[i];
}
}

@@ -142,6 +340,13 @@ fused_add_rms_norm_kernel(

} // namespace vllm

static inline int ln_block_threads_unified(int H) {
int threads = (H >= 1024) ? 256
: (H >= 512) ? 512
: std::min(1024, ((H + 31) / 32) * 32);
return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
}
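// Heuristic examples: H = 8192 -> 256 threads, H = 512 -> 512 threads,
// H = 100 -> 128 threads after the round-to-warp and the [128, 1024] clamp.
// Fewer threads for wide rows presumably trades per-block parallelism for
// more resident blocks per SM.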

void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
@@ -150,18 +355,32 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());

int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2);
const int hidden_size = input.size(-1);
const int num_tokens = input.numel() / hidden_size;
const int64_t in_stride = input.stride(-2);

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
dim3 block(ln_block_threads_unified(hidden_size));

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Optional per-block row cache in dynamic shared memory.
// If enabled (FP16 and hidden_size <= 4096), the kernel copies the row to
// smem once and reuses it on the second pass to cut a global read. If
// shmem_bytes == 0, the kernel takes the non-cached path.
size_t shmem_bytes = 0;
int smem_elems = 0;
if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
smem_elems = hidden_size; // flag to kernel that shmem was provisioned
}
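// Size check: the largest cached row is 4096 halves = 8 KiB of dynamic
// shared memory, comfortably under the 48 KiB default per-block limit on
// NVIDIA GPUs, so no cudaFuncSetAttribute opt-in is needed.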

VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
vllm::rms_norm_kernel<scalar_t><<<grid, block, shmem_bytes, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), in_stride,
weight.data_ptr<scalar_t>(), static_cast<float>(epsilon), num_tokens,
hidden_size, smem_elems);
});
}
