Vectorize RMSNorm CUDA kernel #22602
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
This pull request was exported from Phabricator. Differential Revision: D79969610
Code Review
This pull request significantly refactors the RMSNorm CUDA kernels to improve performance through vectorization, shared-memory caching, and unified reduction logic. The benchmark results show a substantial speedup. The changes are well-structured and the optimizations follow common CUDA best practices. However, I found a critical integer overflow bug in the quantized fused kernel: while the input stride was correctly updated to int64_t to support large tensors, the corresponding index variables were left as int, which could lead to incorrect memory access and corrupted results.
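To illustrate the failure mode the review is pointing at, here is a minimal sketch (the variable names are hypothetical, not the kernel's actual ones):

```cuda
#include <cstdint>

// Illustrative sketch only: with an int64_t input stride, every variable that
// holds a row or element offset must also be 64-bit, otherwise the offset
// silently truncates once it passes 2^31 elements on large tensors.
__device__ inline const float* row_ptr(const float* input, int64_t input_stride) {
  // int offset = blockIdx.x * input_stride;                           // would truncate to 32 bits
  int64_t offset = static_cast<int64_t>(blockIdx.x) * input_stride;    // keeps the full width
  return input + offset;
}
```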
Force-pushed 2e9b93f to 79faf5c
csrc/layernorm_kernels.cu (Outdated)
```cuda
template <typename T>
__device__ __forceinline__ T warp_sum(T v) {
#ifdef __HIP_PLATFORM_AMD__
  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
#else
  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
#endif
  // Always reduce over 32 lanes to match downstream logic.
  constexpr int kWidth = 32;
  v += __shfl_down_sync(m, v, 16, kWidth);
  v += __shfl_down_sync(m, v, 8, kWidth);
  v += __shfl_down_sync(m, v, 4, kWidth);
  v += __shfl_down_sync(m, v, 2, kWidth);
  v += __shfl_down_sync(m, v, 1, kWidth);
  return v;
}
```
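For context, a warp-level reduction like this is usually composed into a block-wide sum by staging one partial per warp in shared memory; a minimal sketch of that common pattern (not the PR's actual reduction or its ln_block_threads_unified helper) is:

```cuda
// Sketch: block-wide sum built on top of warp_sum. Each warp reduces its
// lanes, lane 0 publishes the partial, and the first warp reduces the
// partials. Assumes blockDim.x is a multiple of 32 and at most 1024.
__device__ float block_sum(float v) {
  __shared__ float warp_partials[32];          // one slot per possible warp
  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  v = warp_sum(v);                             // intra-warp reduction
  if (lane == 0) warp_partials[warp] = v;      // publish each warp's sum
  __syncthreads();
  const int num_warps = blockDim.x / 32;
  v = (threadIdx.x < num_warps) ? warp_partials[lane] : 0.f;
  if (warp == 0) v = warp_sum(v);              // reduce the per-warp partials
  return v;                                    // result valid in thread 0
}
```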
Note: warp_sum and ln_block_threads_unified are the same for rms_norm_kernel and rms_norm_static_fp8_quant. Will move them to a shared helper after aligning on the high-level approach.
cc @yewentao256
Thanks for the work!
Could we also test end-to-end performance & accuracy?
For accuracy: lm_eval ...
For E2E performance (optional, nice to have): vllm bench ...
csrc/layernorm_kernels.cu (Outdated)
```cuda
  // -------- Fallback scalar --------
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
    out_row[i] = xn * weight[i];
  }
```
Same as above.
Kept the scalar fallback for now. The vectorization util (even V=1) peels a prefix and changes per-thread write order, which causes small FP8 mismatches vs the fused path on unaligned shapes. Happy to revisit once fused ordering matches or revise further if you have any preferences.
The scalar fallback inside the util is actually there to handle address misalignment; it copes with the pre-aligned prefix and the tail. Could you give me an example of the mismatches? Then we can see together whether we should update the util itself.
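For readers following the thread, the prefix/tail handling under discussion follows roughly this pattern (a generic sketch with hypothetical names, not the actual util in csrc); the point is that the prefix and tail are walked with a different thread-to-index mapping than a plain strided scalar loop, which is where small ordering differences can appear:

```cuda
// Generic sketch of a width-V vectorized elementwise loop with scalar peeling.
// Real code would reinterpret_cast the aligned body to a V-wide vector type;
// here the body is written element-wise to keep the indexing visible.
template <int V, typename T, typename Op>
__device__ void vectorized_apply(const T* in, T* out, int n, Op op) {
  // Scalar prefix: run until 'in' reaches a V-element boundary.
  const int misalign = (reinterpret_cast<size_t>(in) / sizeof(T)) % V;
  const int prefix = misalign ? (V - misalign) : 0;
  for (int i = threadIdx.x; i < prefix && i < n; i += blockDim.x)
    out[i] = op(in[i]);
  // Aligned body: each thread owns whole V-element groups.
  const int n_vec = (n > prefix) ? (n - prefix) / V : 0;
  for (int i = threadIdx.x; i < n_vec; i += blockDim.x)
    for (int j = 0; j < V; ++j)
      out[prefix + i * V + j] = op(in[prefix + i * V + j]);
  // Scalar tail: whatever is left after the last full group.
  for (int i = prefix + n_vec * V + threadIdx.x; i < n; i += blockDim.x)
    out[i] = op(in[i]);
}
```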
Thanks for the feedback @yewentao256. Will work on addressing and testing e2e.
Summary:

What: Make RMSNorm faster by reading data in bigger aligned chunks, caching fp16 rows in shared memory, making the FP8-quant version work with strided inputs, and using the same launch settings as the unfused path.

Why: Cut global memory traffic using aligned vector inputs/outputs and shared-mem reuse (avoids a second read), make the FP8 path safe for strided inputs, and preserve numerics by matching the unfused reduction/launch order.

Test Plan:

1) Run tests

```
[[email protected] /data/users/benjibeck/fbsource/fbcode/vllm (1043a27694)]$ buck2 test :test_kernels_layernorm
Buck UI: https://www.internalfb.com/buck2/054ebad3-ad92-4676-a4d2-3bf43e44f31a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10414574240710255
Network: Up: 152MiB Down: 2.9GiB (reSessionID-14af330c-26bf-41d5-87b0-5775bf7d6f8a)
Loading targets. Remaining 0/7 150 dirs read, 69 targets declared
Analyzing targets. Remaining 0/32 772 actions, 819 artifacts declared
Executing actions. Remaining 0/245 48.3s exec time total
Command: test. Finished 1 local, 14 remote, 131 cache (90% hit) 45.2s exec time cached (93%)
Time elapsed: 4:53.4s
Tests finished: Pass 3169. Fail 0. Fatal 0. Skip 0. Build failure 0
```

2) Run benchmark

```
buck run :benchmark_layernorm -- --num-tokens 16384 --hidden-size 1024 --dtype half --num-iters 500
```

Performance (selected, non-strided / HND)

| T | H | dtype | baseline (µs) | current (µs) | Δ |
| --- | --- | --- | --- | --- | --- |
| 4096 | 1024 | fp16 | 24.592 | 16.552 | -32.7% |
| 16384 | 1024 | fp16 | 106.699 | 42.739 | -60.0% |
| 4096 | 8192 | fp16 | 118.566 | 97.059 | -18.1% |
| 16384 | 8192 | fp16 | 450.738 | 356.125 | -21.0% |
| 4096 | 1024 | bf16 | 24.743 | 16.683 | -32.6% |
| 16384 | 1024 | bf16 | 107.009 | 56.946 | -46.8% |
| 4096 | 8192 | bf16 | 119.293 | 96.774 | -18.9% |
| 16384 | 8192 | bf16 | 451.181 | 357.799 | -20.7% |

Strided (NHD) overhead (current kernel) — penalty vs. HND (same T/H/dtype):
- 4096×1024 fp16: 1.39× (22.983 / 16.552)
- 16384×1024 fp16: 2.13× (90.995 / 42.739)
- 4096×8192 fp16: 1.93× (186.931 / 97.059)

Rollback Plan:

Differential Revision: D79969610
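As a rough illustration of the "aligned chunks + shared-memory row cache" idea described in the summary, a simplified two-pass sketch follows (hypothetical names, scalar loads standing in for the vectorized ones, and the block_sum helper from the earlier sketch; assumes the kernel is launched with hidden_size * sizeof(scalar_t) bytes of dynamic shared memory):

```cuda
// Simplified sketch: one block per row. Pass 1 accumulates the sum of squares
// while caching the row in shared memory; pass 2 normalizes from the cache so
// the row is read from global memory only once.
template <typename scalar_t>
__global__ void rms_norm_smem_sketch(scalar_t* __restrict__ out,
                                     const scalar_t* __restrict__ in,
                                     const scalar_t* __restrict__ weight,
                                     int hidden_size, float eps) {
  extern __shared__ unsigned char smem_raw[];
  scalar_t* s_row = reinterpret_cast<scalar_t*>(smem_raw);
  const scalar_t* in_row = in + static_cast<int64_t>(blockIdx.x) * hidden_size;
  scalar_t* out_row = out + static_cast<int64_t>(blockIdx.x) * hidden_size;

  float ssq = 0.f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {  // pass 1: load + cache
    float x = static_cast<float>(in_row[i]);  // the real kernel uses wide, aligned loads here
    s_row[i] = in_row[i];
    ssq += x * x;
  }
  __shared__ float s_inv_rms;
  float total = block_sum(ssq);               // block-wide reduction (see earlier sketch)
  if (threadIdx.x == 0) s_inv_rms = rsqrtf(total / hidden_size + eps);
  __syncthreads();

  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {  // pass 2: read from smem
    scalar_t xn = static_cast<scalar_t>(static_cast<float>(s_row[i]) * s_inv_rms);
    out_row[i] = xn * weight[i];
  }
}
```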
Force-pushed 79faf5c to 18ba33f
This pull request was exported from Phabricator. Differential Revision: D79969610
Force-pushed 18ba33f to 4b85028
This pull request was exported from Phabricator. Differential Revision: D79969610
Force-pushed 4b85028 to 16290a7
Signed-off-by: Benji Beck <[email protected]>
Update: Performed an E2E latency benchmark using 2× H100 and observed a ~4.8% reduction with the vectorized rms_norm_kernel without shared memory. I'm planning to benchmark further with other models/configurations to understand this better. Let me know if you folks have any thoughts/preferences.
Ran quick E2E latency on a few models (bs=128, in=32, out=128). Overall: small win on Llama-3.2, neutral on Qwen fp16, and Gemma-2-9B works better if we skip the shared-mem cached row. I can rerun with further settings or add other models if there are any preferences. @yewentao256 Wondering if we should consider adding a knob to disable the cached-row path where it regresses for some models? A sketch of what that could look like follows below.
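If we do add such a knob, one lightweight shape it could take is an opt-out read once on the host and checked at launch time; everything below, including the variable name, is hypothetical:

```cuda
#include <cstdlib>

// Hypothetical knob: cache the env-var lookup once and use it at launch time
// to choose between the cached-row and non-cached kernel variants.
static bool rms_norm_use_smem_cache() {
  static const bool enabled = [] {
    const char* v = std::getenv("VLLM_RMS_NORM_SMEM_CACHE");  // invented name, for illustration
    return v == nullptr || v[0] != '0';                       // default: enabled
  }();
  return enabled;
}
```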
Note: These changes were made for testing within Meta infra and will be deleted before landing.
cc @mgoin @WoosukKwon for any additional thoughts.
Hi @yewentao256, just following up in case you have time to review. Would be great if we could align on these results before making further changes. Thanks!
Sorry for the late response; there were some problems with my GitHub notifications, but they are fixed now.
The overall work looks good. Could you:
- test the accuracy using lm_eval?
- benchmark the kernel directly?
One example can be seen here: #22036
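For timing the kernel directly, the usual CUDA-event pattern looks roughly like this (a generic sketch, independent of the example linked above; 'launch_rms_norm' is a placeholder, not a real vLLM entry point):

```cuda
#include <cuda_runtime.h>

// Generic kernel-timing sketch with CUDA events: warm up, then time N
// back-to-back launches and report the mean time per launch.
template <typename LaunchFn>
float time_kernel_ms(LaunchFn launch_rms_norm, int iters = 500) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  for (int i = 0; i < 10; ++i) launch_rms_norm();  // warmup
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) launch_rms_norm();
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms / iters;                               // mean milliseconds per launch
}
```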
Summary:
• Vectorize writes when in/out phases align; scalar reads for weight when needed
• Use an optional fp16 shared-mem cache (avoids second read)
• Preserve numerics on odd/strided shapes by keeping scalar fallback
• Match unfused launch/reduction order
Test Plan:
Performance (selected, non-strided / HND) and strided (NHD) overhead vs. HND (same T/H/dtype): see the benchmark tables in the test plan above.
Differential Revision: D79969610