diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py
index 69978ec6b23e..037ca681be83 100644
--- a/benchmarks/kernels/benchmark_layernorm.py
+++ b/benchmarks/kernels/benchmark_layernorm.py
@@ -11,7 +11,7 @@
 
 
 @torch.inference_mode()
-def main(
+def run_benchmark(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
@@ -59,7 +59,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     print(f"Kernel running time: {latency * 1000000:.3f} us")
 
 
-if __name__ == "__main__":
+def main():
+    """Main function for Buck compatibility."""
     parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
     parser.add_argument("--num-tokens", type=int, default=4096)
     parser.add_argument("--hidden-size", type=int, default=8192)
@@ -81,7 +82,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     args = parser.parse_args()
     print(args)
 
-    main(
+    run_benchmark(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        add_residual=args.add_residual,
@@ -91,3 +92,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
     )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index f051eb070222..2ad14f83c1ad 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -3,6 +3,7 @@
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
 
+#include "quantization/vectorization_utils.cuh"
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
@@ -12,35 +13,254 @@ namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
+// ---------- helpers ----------
+template <typename T>
+__device__ __forceinline__ T warp_sum(T v) {
+#ifdef __HIP_PLATFORM_AMD__
+  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
+#else
+  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
+#endif
+  // Always reduce over 32 lanes to match downstream logic.
+  constexpr int kWidth = 32;
+  v += __shfl_down_sync(m, v, 16, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
+  return v;
+}
+
+template <typename T>
+__device__ __forceinline__ bool same_phase(const T* a, const T* b, int widthB) {
+  auto ai = reinterpret_cast<uintptr_t>(a);
+  auto bi = reinterpret_cast<uintptr_t>(b);
+  return ((ai ^ bi) & (widthB - 1)) == 0;
+}
+
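The same_phase check only compares the address bits below the vector width, so two pointers that share the same misalignment can still be covered by a single scalar prefix. A minimal host-side sketch of that predicate (illustrative only; same_phase_host is a hypothetical name, not part of the patch):

#include <cassert>
#include <cstdint>

// Two pointers are "in phase" for width_bytes-wide vector accesses when their
// addresses are congruent modulo width_bytes (width_bytes must be a power of two).
static bool same_phase_host(const void* a, const void* b, int width_bytes) {
  auto ai = reinterpret_cast<uintptr_t>(a);
  auto bi = reinterpret_cast<uintptr_t>(b);
  return ((ai ^ bi) & (width_bytes - 1)) == 0;
}

int main() {
  alignas(16) char buf[64];
  assert(same_phase_host(buf, buf + 32, 16));      // both 16B-aligned
  assert(same_phase_host(buf + 2, buf + 18, 16));  // same 2-byte misalignment
  assert(!same_phase_host(buf, buf + 8, 16));      // phases differ by 8 bytes
}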
+// Safe 16B copy to shared: prefix to align, vector main, scalar tail.
+template <typename T>
+__device__ __forceinline__ void copy_row_to_shared_aligned(
+    const T* __restrict__ src, T* __restrict__ dst, int n_elems, int tid) {
+  const uintptr_t sa = reinterpret_cast<uintptr_t>(src);
+  const uintptr_t da = reinterpret_cast<uintptr_t>(dst);
+  const bool same16 = (((sa ^ da) & (16 - 1)) == 0);
+
+  if (!same16) {
+    for (int i = tid; i < n_elems; i += blockDim.x) dst[i] = src[i];
+    __syncthreads();
+    return;
+  }
+  const int ebytes = sizeof(T);
+  const int perVec = 16 / ebytes;
+
+  int prefix = 0;
+  const int mis = sa & (16 - 1);
+  if (mis) prefix = (16 - mis) / ebytes;
+  if (prefix > n_elems) prefix = n_elems;
+
+  // scalar prefix
+  for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];
+
+  // vector main
+  const int remain = n_elems - prefix;
+  const int main_elems = (remain / perVec) * perVec;
+  if (main_elems > 0) {
+    const uint4* __restrict__ vsrc =
+        reinterpret_cast<const uint4*>(src + prefix);
+
+#if defined(__HIP_PLATFORM_AMD__)
+    // ROCm: vector load from global, scalar 32-bit stores to shared
+    uint32_t* __restrict__ s32 = reinterpret_cast<uint32_t*>(dst + prefix);
+    const int nvec = main_elems / perVec;                 // 16B packets
+    constexpr int WORDS_PER_PKT = 16 / sizeof(uint32_t);  // = 4
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      uint4 p = vsrc[v];
+      const int base = v * WORDS_PER_PKT;
+      s32[base + 0] = p.x;
+      s32[base + 1] = p.y;
+      s32[base + 2] = p.z;
+      s32[base + 3] = p.w;
+    }
+#else
+    // CUDA: vector load + vector store (fastest)
+    uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst + prefix);
+    const int nvec = main_elems / perVec;
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      uint4 p = vsrc[v];
+      vdst[v] = p;
+    }
+#endif
+  }
+
+  // scalar tail
+  const int tail = prefix + main_elems;
+  for (int i = tid + tail; i < n_elems; i += blockDim.x) dst[i] = src[i];
+  __syncthreads();
+}
+
+// ---------------- vec/scalar ops (generic, used for all dtypes)
+// ----------------
+template <typename T, int V>
+struct VecMulNormWeight {
+  const vec_n_t<T, V>* __restrict__ wv;  // vector view of weight (aligned with
+                                         // in/out)
+  float inv_rms;
+  int stride_vec;
+  mutable int64_t vec_idx;
+
+  __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                             const vec_n_t<T, V>& src) const {
+    const vec_n_t<T, V> w = wv[vec_idx];
+#pragma unroll
+    for (int j = 0; j < V; ++j) {
+      T xn = static_cast<T>(static_cast<float>(src.val[j]) * inv_rms);
+      dst.val[j] = xn * w.val[j];
+    }
+    vec_idx += stride_vec;
+  }
+};
+
+template <typename T>
+struct ScalarMulNormWeight {
+  const T* __restrict__ w_base;  // already offset by +prefix
+  T* __restrict__ out_base;      // out_row + prefix
+  float inv_rms;
+  __device__ __forceinline__ void operator()(T& dst, const T src) const {
+    const int i = static_cast<int>(&dst - out_base);
+    T xn = static_cast<T>(static_cast<float>(src) * inv_rms);
+    dst = xn * w_base[i];
+  }
+};
+
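These two functors are handed to the vectorize helper from the included quantization/vectorization_utils.cuh. What they assume, as visible from the code above, is that the vector op is invoked once per aligned vec_n_t<T, V> packet in thread order with stride blockDim.x, while the scalar op handles the leftovers and recovers its index from the destination address. A simplified driver illustrating that contract (a sketch only; apply_row is a hypothetical stand-in for the real helper):

template <typename T, int V, typename VecOp, typename ScaOp>
__device__ void apply_row(const T* __restrict__ in, T* __restrict__ out,
                          int len, int tid, int stride,
                          VecOp&& vec_op, ScaOp&& sca_op) {
  using Vec = vec_n_t<T, V>;
  const int nvec = len / V;  // full V-wide (16B) packets; caller ensures alignment
  const Vec* vin = reinterpret_cast<const Vec*>(in);
  Vec* vout = reinterpret_cast<Vec*>(out);
  // One packet per call, in tid order: matches VecMulNormWeight's vec_idx walk.
  for (int i = tid; i < nvec; i += stride) vec_op(vout[i], vin[i]);
  // Scalar remainder: ScalarMulNormWeight derives the index from &dst.
  for (int i = nvec * V + tid; i < len; i += stride) sca_op(out[i], in[i]);
}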
+// ---------- kernel ----------
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,   // [..., hidden_size]
     const int64_t input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
-    const float epsilon, const int num_tokens, const int hidden_size) {
-  __shared__ float s_variance;
-  float variance = 0.0f;
+    const float epsilon, const int /*num_tokens*/, const int hidden_size,
+    int smem_elems) {
+  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
-    variance += x * x;
+  // Optional cached row (fp16) when the host provisioned shmem
+  extern __shared__ unsigned char smem_raw[];
+  scalar_t* s_in = reinterpret_cast<scalar_t*>(smem_raw);
+
+#ifdef __HIP_PLATFORM_AMD__
+  constexpr bool kAllowCache = false;
+#else
+  constexpr bool kAllowCache = true;
+#endif
+  const bool use_cached =
+      kAllowCache && (sizeof(scalar_t) == 2) && (smem_elems > 0);
+
+#if !defined(__HIP_PLATFORM_AMD__)
+  if (use_cached) {
+    copy_row_to_shared_aligned(in_row, s_in, hidden_size, threadIdx.x);
   }
+#endif
 
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  // -------- Pass 1: sum of squares --------
+  using acc_t = float;
+  acc_t sumsq = acc_t(0);
+  if (use_cached) {
+    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      acc_t x = static_cast<acc_t>(s_in[i]);
+      sumsq += x * x;
+    }
+  } else {
+    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      acc_t x = static_cast<acc_t>(in_row[i]);
+      sumsq += x * x;
+    }
+  }
 
-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  // warp + block reduction in acc_t
+  acc_t wsum = warp_sum(sumsq);
+  __shared__ acc_t warp_sums_sh[32];
+  if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
+  __syncthreads();
+
+  acc_t total = acc_t(0);
+  if (threadIdx.x < 32) {
+    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x]
+                                                     : acc_t(0);
+    total = warp_sum(v);
+    if (threadIdx.x == 0) warp_sums_sh[0] = total;
   }
   __syncthreads();
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+  // compute inv_rms in float to match baseline epsilon semantics
+  const float inv_rms = rsqrtf(
+      static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) +
+      epsilon);
+
+  // -------- Fast path: HS == blockDim.x (e.g., 1024) --------
+  if (hidden_size == blockDim.x) {
+    int i = threadIdx.x;
+    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
+    return;
+  }
+
+  // -------- Pass 2: Vectorize when phases align --------
+  constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B packets
+  constexpr int WIDTH = V * sizeof(scalar_t);
+
+  const bool can_vec = (hidden_size % V == 0) &&
+                       same_phase(in_row, out_row, WIDTH) &&
+                       same_phase(in_row, weight, WIDTH);
+
+  if (can_vec) {
+    const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
+    const int mis = addr & (WIDTH - 1);
+    const int prefix = mis ? (WIDTH - mis) / (int)sizeof(scalar_t) : 0;
+
+    // scalar prefix
+    for (int i = threadIdx.x; i < prefix; i += blockDim.x) {
+      float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+      out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+    }
+
+    // vector main
+    const int remain = hidden_size - prefix;
+    const int main_len = (remain / V) * V;
+    if (main_len > 0) {
+      using VecT = vec_n_t<scalar_t, V>;
+      const VecT* __restrict__ wv =
+          reinterpret_cast<const VecT*>(weight + prefix);
+
+      VecMulNormWeight<scalar_t, V> vec_op{/*wv=*/wv,
+                                           /*inv_rms=*/inv_rms,
+                                           /*stride_vec=*/(int)blockDim.x,
+                                           /*vec_idx=*/(int64_t)threadIdx.x};
+      ScalarMulNormWeight<scalar_t> sca_op{/*w_base=*/weight + prefix,
+                                           /*out_base=*/out_row + prefix,
+                                           /*inv_rms=*/inv_rms};
+
+      const scalar_t* vin = use_cached ? (s_in + prefix) : (in_row + prefix);
+      vectorize_with_alignment<V>(vin, out_row + prefix, main_len, threadIdx.x,
+                                  blockDim.x, vec_op, sca_op);
+    }
+
+    // scalar tail
+    const int tail = prefix + ((remain / V) * V);
+    for (int i = threadIdx.x + tail; i < hidden_size; i += blockDim.x) {
+      float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+      out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
+    }
+    return;
+  }
+
+  // -------- Fallback scalar --------
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
   }
 }
 
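The arithmetic the kernel preserves, written out as a scalar reference (a host-side sketch for clarity, not part of the patch): the row is normalized in fp32, the normalized value is cast back to the element type, and only then multiplied by the weight, which is the rounding order of the baseline kernel.

#include <cmath>

template <typename T>
void rms_norm_ref(const T* x, const T* w, T* out, int H, float eps) {
  float sumsq = 0.f;
  for (int i = 0; i < H; ++i) sumsq += float(x[i]) * float(x[i]);
  const float inv_rms = 1.0f / std::sqrt(sumsq / H + eps);
  for (int i = 0; i < H; ++i)
    out[i] = T(float(x[i]) * inv_rms) * w[i];  // cast to T before the weight multiply
}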
@@ -142,6 +362,13 @@ fused_add_rms_norm_kernel(
 
 }  // namespace vllm
 
+static inline int ln_block_threads_unified(int H) {
+  int threads = (H >= 1024) ? 256
+                : (H >= 512) ? 512
+                             : std::min(1024, ((H + 31) / 32) * 32);
+  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
+}
+
 void rms_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
@@ -150,18 +377,30 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  int64_t input_stride = input.stride(-2);
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int64_t in_stride = input.stride(-2);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  dim3 block(ln_block_threads_unified(hidden_size));
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Optional cached row for FP16 (recommended). The kernel still works if
+  // shmem_bytes stays 0.
+  size_t shmem_bytes = 0;
+  int smem_elems = 0;
+  if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
+    shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
+    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
+  }
+
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
-        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
+    vllm::rms_norm_kernel<scalar_t><<<grid, block, shmem_bytes, stream>>>(
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), in_stride,
+        weight.data_ptr<scalar_t>(), static_cast<float>(epsilon), num_tokens,
+        hidden_size, smem_elems);
   });
 }
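The block-size helper deliberately uses fewer threads for wide rows so each thread covers several 16-byte packets, and warp-aligns/clamps the result for narrow rows. Its mapping can be sanity-checked on the host (illustrative snippet; ln_block_threads_unified_ref is simply a copy of the helper above):

#include <algorithm>
#include <cassert>

static int ln_block_threads_unified_ref(int H) {
  int threads = (H >= 1024) ? 256
                : (H >= 512) ? 512
                             : std::min(1024, ((H + 31) / 32) * 32);
  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
}

int main() {
  assert(ln_block_threads_unified_ref(8192) == 256);  // wide rows: 256 threads
  assert(ln_block_threads_unified_ref(1024) == 256);
  assert(ln_block_threads_unified_ref(768) == 512);   // mid-size rows: 512 threads
  assert(ln_block_threads_unified_ref(300) == 320);   // narrow: warp-aligned H
  assert(ln_block_threads_unified_ref(64) == 128);    // clamped to >= 128
}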
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
index 0fd5849d9626..51b57360b082 100644
--- a/csrc/layernorm_quant_kernels.cu
+++ b/csrc/layernorm_quant_kernels.cu
@@ -20,40 +20,67 @@ namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
+template <typename T>
+__device__ __forceinline__ T warp_sum(T v) {
+#ifdef __HIP_PLATFORM_AMD__
+  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
+#else
+  const unsigned m = 0xffffffffu;
+#endif
+  constexpr int kWidth = 32;  // keep reduction over 32 lanes everywhere
+  v += __shfl_down_sync(m, v, 16, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
+  return v;
+}
+
+// kernel unchanged (uses warp_sum + same math as unfused)
 template <typename scalar_t, typename fp8_type>
 __global__ void rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,           // [..., hidden_size]
-    const scalar_t* __restrict__ input,   // [..., hidden_size]
-    const int input_stride,
-    const scalar_t* __restrict__ weight,  // [hidden_size]
+    fp8_type* __restrict__ out,           // [T, H]
+    const scalar_t* __restrict__ input,   // [T, last_dim], may be strided
+    const int64_t input_stride,           // <-- int64_t
+    const scalar_t* __restrict__ weight,  // [H]
     const float* __restrict__ scale,      // [1]
-    const float epsilon, const int num_tokens, const int hidden_size) {
-  __shared__ float s_variance;
-  float variance = 0.0f;
+    const float epsilon, const int /*num_tokens*/, const int hidden_size) {
+  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
-    variance += x * x;
+  using acc_t = float;
+  acc_t sumsq = acc_t(0);
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    acc_t x = static_cast<acc_t>(in_row[i]);
+    sumsq += x * x;
   }
 
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  // identical reduction to unfused
+  acc_t wsum = warp_sum(sumsq);
+  __shared__ acc_t warp_sums_sh[32];
+  if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
+  __syncthreads();
 
-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  if (threadIdx.x < 32) {
+    acc_t v = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_sums_sh[threadIdx.x]
+                                                     : acc_t(0);
+    acc_t total = warp_sum(v);
+    if (threadIdx.x == 0) warp_sums_sh[0] = total;
   }
   __syncthreads();
 
-  // invert scale to avoid division
-  float const scale_inv = 1.0f / *scale;
+  const float inv_rms = rsqrtf(
+      static_cast<float>(warp_sums_sh[0] / static_cast<acc_t>(hidden_size)) +
+      epsilon);
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
-    out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
+  const float scale_inv = 1.0f / (*scale);
+
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    const float x_f = static_cast<float>(in_row[i]);
+    const scalar_t xn =
+        static_cast<scalar_t>(x_f * inv_rms);  // fp32 normalize → cast to T
+    const scalar_t z = xn * weight[i];         // multiply in T
+    out[blockIdx.x * hidden_size + i] =
+        scaled_fp8_conversion<true, fp8_type>(static_cast<float>(z), scale_inv);
   }
 }
 
@@ -66,7 +93,7 @@ __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
-    const int input_stride,
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -76,7 +103,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
 
   const int vec_hidden_size = hidden_size / width;
-  const int vec_input_stride = input_stride / width;
+  const int64_t vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
@@ -131,7 +158,7 @@ __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
-    const int input_stride,
+    const int64_t input_stride,
    scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
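Widening input_stride to int64_t matters because the row base `input + blockIdx.x * input_stride` is computed in the stride's type, and for large token counts the element offset no longer fits in 32 bits. A quick host-side illustration with plausible (hypothetical) shapes:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t hidden_size = 8192;
  const int64_t num_tokens = 600000;  // e.g. a large batch * seq_len
  const int64_t max_offset = (num_tokens - 1) * hidden_size;
  // 4'915'191'808 elements: exceeds both INT32_MAX and UINT32_MAX, so 32-bit
  // index arithmetic would wrap before reaching the last row.
  std::printf("max element offset = %lld\n", (long long)max_offset);
}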
@@ -169,20 +196,37 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 
 }  // namespace vllm
 
-void rms_norm_static_fp8_quant(torch::Tensor& out,     // [..., hidden_size]
-                               torch::Tensor& input,   // [..., hidden_size]
-                               torch::Tensor& weight,  // [hidden_size]
+// --- shared: match unfused launch exactly ---
+static inline int ln_block_threads_unified(int hidden_size) {
+  int threads = (hidden_size >= 1024) ? 256
+                : (hidden_size >= 512)
+                    ? 512
+                    : std::min(1024, ((hidden_size + 31) / 32) * 32);
+  // warp-align and clamp to [128, 1024]
+  threads = std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
+  return threads;
+}
+
+void rms_norm_static_fp8_quant(torch::Tensor& out,     // [T, H]
+                               torch::Tensor& input,   // [T, last_dim]
+                               torch::Tensor& weight,  // [H]
                                torch::Tensor& scale,   // [1]
                                double epsilon) {
   TORCH_CHECK(out.is_contiguous());
-  int hidden_size = input.size(-1);
-  int input_stride = input.stride(-2);
-  int num_tokens = input.numel() / hidden_size;
+  TORCH_CHECK(weight.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1, "last dim must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int64_t input_stride =
+      input.stride(-2);  // row stride (== last_dim when 2D)
+  const int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  dim3 block(ln_block_threads_unified(hidden_size));  // <-- match unfused
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
@@ -191,10 +235,11 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                       input_stride, weight.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), epsilon, num_tokens,
-                      hidden_size);
+                      scale.data_ptr<float>(), static_cast<float>(epsilon),
+                      num_tokens, hidden_size);
             });
       });
+  // TORCH_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
@@ -221,7 +266,7 @@ void fused_add_rms_norm_static_fp8_quant(
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(residual.is_contiguous());
   int hidden_size = input.size(-1);
-  int input_stride = input.stride(-2);
+  int64_t input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
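A note on the shared-memory budget of the cached-row path in rms_norm(): the host only provisions dynamic shared memory for fp16 rows of at most 4096 elements, i.e. at most 8 KiB per block, which is well under the default 48 KiB per-block limit, so no cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...) opt-in is required. A small illustrative check (not part of the patch):

#include <cstddef>
#include <cstdio>

int main() {
  const int max_cached_hidden = 4096;                   // host-side gate in rms_norm()
  const size_t bytes = size_t(max_cached_hidden) * 2;   // sizeof(half)
  std::printf("max dynamic smem requested: %zu bytes\n", bytes);  // 8192
}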