Vectorize RMSNorm CUDA kernel #22602
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

This pull request was exported from Phabricator. Differential Revision: D79969610
Code Review
This pull request significantly refactors the RMSNorm CUDA kernels to improve performance through vectorization, shared memory caching, and unified reduction logic. The benchmark results show a substantial speedup. The changes are well-structured and the optimizations follow common CUDA best practices. However, I found a critical integer overflow bug in the quantized fused kernel. While the input stride was correctly updated to `int64_t` to support large tensors, the corresponding index variables were left as `int`, which could lead to incorrect memory access and corrupted results.
@@ -76,7 +105,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

   const int vec_hidden_size = hidden_size / width;
-  const int vec_input_stride = input_stride / width;
+  const int64_t vec_input_stride = input_stride / width;
Changing `vec_input_stride` to `int64_t` is correct, but this change is incomplete and introduces a potential integer overflow. The variables `stride_id` and `id` on lines 122-123, which are used for indexing, are still declared as `int`.

For large tensors where `num_tokens` (and thus `blockIdx.x`) is large, the multiplication `blockIdx.x * vec_input_stride` can exceed the maximum value of an `int`, leading to an overflow. This would result in incorrect memory accesses and corrupted output.

To fix this, `stride_id` and `id` should also be declared as `int64_t`:

    int64_t stride_id = blockIdx.x * vec_input_stride + idx;
    int64_t id = blockIdx.x * vec_hidden_size + idx;
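One further detail worth double-checking on top of that suggestion (a note added here, not part of the review): `blockIdx.x` is a 32-bit unsigned value and `vec_hidden_size` is still an `int`, so the second product is evaluated in 32-bit arithmetic before the result is widened. An explicit cast forces the multiply itself into 64 bits:

```cuda
// Promote before multiplying, not only at the destination.
const int64_t stride_id = static_cast<int64_t>(blockIdx.x) * vec_input_stride + idx;
const int64_t id = static_cast<int64_t>(blockIdx.x) * vec_hidden_size + idx;
```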
Summary:

What: Make RMSNorm faster by reading data in bigger aligned chunks, caching fp16 rows in shared memory, making the FP8-quant version work with strided inputs, and using the same launch settings as the unfused path.

Why: Cut global memory traffic using aligned vector inputs/outputs and shared-mem reuse (avoids second read), make the FP8 path safe for strided inputs, and preserve numerics by matching the unfused reduction/launch order.

Test Plan:

1) Run tests

```
[[email protected] /data/users/benjibeck/fbsource/fbcode/vllm (1043a27694)]$ buck2 test :test_kernels_layernorm
Buck UI: https://www.internalfb.com/buck2/054ebad3-ad92-4676-a4d2-3bf43e44f31a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10414574240710255
Network: Up: 152MiB Down: 2.9GiB (reSessionID-14af330c-26bf-41d5-87b0-5775bf7d6f8a)
Loading targets. Remaining 0/7 150 dirs read, 69 targets declared
Analyzing targets. Remaining 0/32 772 actions, 819 artifacts declared
Executing actions. Remaining 0/245 48.3s exec time total
Command: test. Finished 1 local, 14 remote, 131 cache (90% hit) 45.2s exec time cached (93%)
Time elapsed: 4:53.4s
Tests finished: Pass 3169. Fail 0. Fatal 0. Skip 0. Build failure 0
```

2) Run benchmark

```
buck run :benchmark_layernorm -- --num-tokens 16384 --hidden-size 1024 --dtype half --num-iters 500
Before -> Kernel running time: 105.918 us
After  -> Kernel running time:  42.571 us
```

Rollback Plan:

Differential Revision: D79969610

Signed-off-by: Benji Beck <[email protected]>
template <typename T>
__device__ __forceinline__ T warp_sum(T v) {
#ifdef __HIP_PLATFORM_AMD__
  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
#else
  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
#endif
  // Always reduce over 32 lanes to match downstream logic.
  constexpr int kWidth = 32;
  v += __shfl_down_sync(m, v, 16, kWidth);
  v += __shfl_down_sync(m, v, 8, kWidth);
  v += __shfl_down_sync(m, v, 4, kWidth);
  v += __shfl_down_sync(m, v, 2, kWidth);
  v += __shfl_down_sync(m, v, 1, kWidth);
  return v;
}
Note: `warp_sum` and `ln_block_threads_unified` are the same for rms_norm_kernel and rms_norm_static_fp8_quant. Will move them to a shared helper after aligning on the high-level approach.

cc @yewentao256
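For the eventual shared helper, the block-wide sum that sits on top of `warp_sum` would typically look something like the sketch below (illustrative only; `block_sum` and its shared-memory staging are assumed names, not code from this PR):

```cuda
// Block-wide sum built on warp_sum: each warp reduces locally, lane 0 of each
// warp publishes its partial to shared memory, and the first warp reduces the
// partials. The final total is valid in thread 0 of the block.
template <typename T>
__device__ __forceinline__ T block_sum(T v) {
  __shared__ T warp_partials[32];        // one slot per warp (<= 1024 threads)
  const int lane = threadIdx.x & 31;
  const int warp = threadIdx.x >> 5;

  v = warp_sum(v);                       // intra-warp reduction
  if (lane == 0) warp_partials[warp] = v;
  __syncthreads();

  const int num_warps = (blockDim.x + 31) >> 5;
  v = (threadIdx.x < num_warps) ? warp_partials[lane] : T(0);
  if (warp == 0) v = warp_sum(v);        // reduce the per-warp partials
  return v;
}
```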
Thanks for the work!

Could we also test for end-to-end performance & accuracy?
- For accuracy: `lm_eval ...`
- For E2E performance (optional, nice to have): `vllm bench ...`
// -------- Pass 2: Vectorize when phases align --------
constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B packets
constexpr int WIDTH = V * sizeof(scalar_t);

const bool can_vec = (hidden_size % V == 0) &&
                     same_phase(in_row, out_row, WIDTH) &&
                     same_phase(in_row, weight, WIDTH);

if (can_vec) {
  const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
  const int mis = addr & (WIDTH - 1);
  const int prefix = mis ? (WIDTH - mis) / (int)sizeof(scalar_t) : 0;

  // scalar prefix
  for (int i = threadIdx.x; i < prefix; i += blockDim.x) {
    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
    out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
  }
We have duplicate logic in `vllm/csrc/quantization/vectorization_utils.cuh`; could we reuse the util?
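If the util's read-side helper fits (its name and signature are assumed below and should be verified against `vectorization_utils.cuh`), the sum-of-squares pass is the most direct candidate for reuse, since it only needs each element's value and not its index; the scale-and-write pass would still need the element index to look up `weight`. A rough sketch under that assumption:

```cuda
// Sketch only: assumes a helper roughly of the form
//   vectorize_read_with_alignment<VEC_SIZE>(ptr, len, tid, stride, op)
// that handles the misaligned prefix/tail and the 16B vector body internally.
#include "quantization/vectorization_utils.cuh"

float sum_sq = 0.f;
vllm::vectorize_read_with_alignment<V>(
    in_row, hidden_size, threadIdx.x, blockDim.x,
    [&] __device__(const scalar_t& x) {
      const float xf = static_cast<float>(x);
      sum_sq += xf * xf;  // per-thread partial; block-reduce afterwards
    });
```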
// -------- Fallback scalar --------
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
  float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
  scalar_t xn = static_cast<scalar_t>(x * inv_rms);
  out_row[i] = xn * weight[i];
Same as above.
  int hidden_size = input.size(-1);
- int input_stride = input.stride(-2);
+ int64_t input_stride = input.stride(-2);
Please also add a comment here explaining why we need to update this to `int64_t`.
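For reference, one possible wording for that comment (a suggestion, not text from the PR):

```cuda
// Keep the stride 64-bit: it is multiplied by the token/block index inside the
// kernel, and for large num_tokens that product can exceed INT_MAX even though
// the stride itself fits in 32 bits.
int64_t input_stride = input.stride(-2);
```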
Summary:
What: Make RMSNorm faster by reading data in bigger aligned chunks, caching fp16 rows in shared memory, making the FP8-quant version work with strided inputs, and using the same launch settings as the unfused path.
Why: Cut global memory traffic using aligned vector inputs/outputs and shared-mem reuse (avoids second read), make the FP8 path safe for strided inputs, and preserve numerics by matching the unfused reduction/launch order.
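A minimal sketch of the first two ideas for the 16-bit dtypes (16-byte packet loads plus a shared-memory row cache), assuming the `_f16Vec` helper already referenced in the diff and hypothetical names `in_row`/`s_in` for the token row and its shared-memory copy:

```cuda
// Pass 1 (illustrative): load the row in aligned 16B packets, cache it in
// shared memory, and accumulate the sum of squares for the RMS.
constexpr int V = 16 / sizeof(scalar_t);        // elements per 16B packet
const auto* in_vec = reinterpret_cast<const _f16Vec<scalar_t, V>*>(in_row);

float sum_sq = 0.f;
for (int i = threadIdx.x; i < hidden_size / V; i += blockDim.x) {
  const _f16Vec<scalar_t, V> pkt = in_vec[i];   // one coalesced 16B load
#pragma unroll
  for (int j = 0; j < V; ++j) {
    const float x = static_cast<float>(pkt.data[j]);
    s_in[i * V + j] = pkt.data[j];              // cache so pass 2 skips global memory
    sum_sq += x * x;
  }
}
// block-reduce sum_sq, compute inv_rms, then pass 2 normalizes from s_in.
```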
Test Plan:
Performance (selected, non-strided / HND)

Strided (NHD) overhead (current kernel): penalty vs. HND (same T/H/dtype):
Differential Revision: D79969610