Vectorize RMSNorm CUDA kernel #22602


Open
wants to merge 2 commits into main

Conversation

Contributor

@bbeckca commented Aug 10, 2025

Summary:
What: Make RMSNorm faster by reading data in bigger aligned chunks, caching fp16 rows in shared memory, making the FP8-quant version work with strided inputs, and using the same launch settings as the unfused path.

Why: Cut global memory traffic using aligned vector inputs/outputs and shared-mem reuse (avoids second read), make the FP8 path safe for strided inputs, and preserve numerics by matching the unfused reduction/launch order.
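As a rough illustration of the approach described above (not the PR's actual kernel, which also handles alignment phases, strided FP8 inputs, and dtype dispatch), here is a minimal CUDA sketch of a per-row RMSNorm that loads 16-byte packets, caches the fp16 row in shared memory, and reuses that cache for the second pass. The names `rms_norm_row_sketch` and `Half8` are hypothetical:

```cpp
// Hypothetical sketch only; the PR's kernels differ in detail.
#include <cuda_fp16.h>

// 16-byte packet of eight fp16 values (stand-in for a vectorized load type).
struct alignas(16) Half8 {
  __half x[8];
};

// One block per token row. Assumes hidden_size % 8 == 0, a 16-byte-aligned
// input row, and blockDim.x a power of two (<= 1024).
__global__ void rms_norm_row_sketch(__half* __restrict__ out,
                                    const __half* __restrict__ in,
                                    const __half* __restrict__ weight,
                                    int64_t input_stride, int hidden_size,
                                    float epsilon) {
  extern __shared__ __half s_in[];  // cached input row (fp16)
  const __half* in_row = in + blockIdx.x * input_stride;
  __half* out_row = out + static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Pass 1: aligned 16-byte loads -> shared memory + per-thread sum of squares.
  float sq_sum = 0.f;
  const Half8* in_vec = reinterpret_cast<const Half8*>(in_row);
  for (int i = threadIdx.x; i < hidden_size / 8; i += blockDim.x) {
    Half8 v = in_vec[i];
#pragma unroll
    for (int j = 0; j < 8; ++j) {
      float xf = __half2float(v.x[j]);
      s_in[i * 8 + j] = v.x[j];
      sq_sum += xf * xf;
    }
  }

  // Simple tree reduction over the block (the PR uses a warp_sum helper instead).
  __shared__ float s_partial[1024];
  s_partial[threadIdx.x] = sq_sum;
  __syncthreads();
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) s_partial[threadIdx.x] += s_partial[threadIdx.x + offset];
    __syncthreads();
  }
  const float inv_rms = rsqrtf(s_partial[0] / hidden_size + epsilon);

  // Pass 2: re-read the row from shared memory instead of global memory.
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float xf = __half2float(s_in[i]);
    out_row[i] = __float2half(xf * inv_rms * __half2float(weight[i]));
  }
}
```

Launching it as `rms_norm_row_sketch<<<num_tokens, 256, hidden_size * sizeof(__half)>>>(...)` matches the one-block-per-token mapping assumed in the sketch.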

Test Plan:

1. Run tests

```
[[email protected] /data/users/benjibeck/fbsource/fbcode/vllm (1043a27694)]$ buck2 test :test_kernels_layernorm
Buck UI: https://www.internalfb.com/buck2/054ebad3-ad92-4676-a4d2-3bf43e44f31a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10414574240710255
Network: Up: 152MiB  Down: 2.9GiB  (reSessionID-14af330c-26bf-41d5-87b0-5775bf7d6f8a)
Loading targets.   Remaining 0/7      150 dirs read, 69 targets declared
Analyzing targets. Remaining 0/32     772 actions, 819 artifacts declared
Executing actions. Remaining 0/245    48.3s exec time total
Command: test.     Finished 1 local, 14 remote, 131 cache (90% hit)   45.2s exec time cached (93%)
Time elapsed: 4:53.4s
Tests finished: Pass 3169. Fail 0. Fatal 0. Skip 0. Build failure 0
```

2. Run benchmark

```
buck run :benchmark_layernorm -- --num-tokens 16384 --hidden-size 1024 --dtype half --num-iters 500
```

Performance (selected, non-strided / HND)

| T | H | dtype | baseline (µs) | current (µs) | Δ |
|------:|-----:|:------|--------------:|-------------:|-------:|
| 4096 | 1024 | fp16 | 24.592 | 16.552 | -32.7% |
| 16384 | 1024 | fp16 | 106.699 | 42.739 | -60.0% |
| 4096 | 8192 | fp16 | 118.566 | 97.059 | -18.1% |
| 16384 | 8192 | fp16 | 450.738 | 356.125 | -21.0% |
| 4096 | 1024 | bf16 | 24.743 | 16.683 | -32.6% |
| 16384 | 1024 | bf16 | 107.009 | 56.946 | -46.8% |
| 4096 | 8192 | bf16 | 119.293 | 96.774 | -18.9% |
| 16384 | 8192 | bf16 | 451.181 | 357.799 | -20.7% |

Strided (NHD) overhead (current kernel) — penalty vs. HND (same T/H/dtype):

  • 4096×1024 fp16: 1.39× (22.983 / 16.552)
  • 16384×1024 fp16: 2.13× (90.995 / 42.739)
  • 4096×8192 fp16: 1.93× (186.931 / 97.059)

Differential Revision: D79969610


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D79969610

@mergify bot added the performance (Performance-related issues) label Aug 10, 2025
Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request significantly refactors the RMSNorm CUDA kernels to improve performance through vectorization, shared memory caching, and unified reduction logic. The benchmark results show a substantial speedup. The changes are well-structured and the optimizations follow common CUDA best practices. However, I found a critical integer overflow bug in the quantized fused kernel. While the input stride was correctly updated to int64_t to support large tensors, the corresponding index variables were left as int, which could lead to incorrect memory access and corrupted results.

```diff
@@ -76,7 +105,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

   const int vec_hidden_size = hidden_size / width;
-  const int vec_input_stride = input_stride / width;
+  const int64_t vec_input_stride = input_stride / width;
```
Contributor

critical

Changing vec_input_stride to int64_t is correct, but this change is incomplete and introduces a potential integer overflow. The variables stride_id and id on lines 122-123, which are used for indexing, are still declared as int.

For large tensors where num_tokens (and thus blockIdx.x) is large, the multiplication blockIdx.x * vec_input_stride can exceed the maximum value of an int, leading to an overflow. This would result in incorrect memory accesses and corrupted output.

To fix this, stride_id and id should also be declared as int64_t:

```cpp
int64_t stride_id = blockIdx.x * vec_input_stride + idx;
int64_t id = blockIdx.x * vec_hidden_size + idx;
```

Comment on lines +17 to +32
```cpp
template <typename T>
__device__ __forceinline__ T warp_sum(T v) {
#ifdef __HIP_PLATFORM_AMD__
  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
#else
  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
#endif
  // Always reduce over 32 lanes to match downstream logic.
  constexpr int kWidth = 32;
  v += __shfl_down_sync(m, v, 16, kWidth);
  v += __shfl_down_sync(m, v, 8, kWidth);
  v += __shfl_down_sync(m, v, 4, kWidth);
  v += __shfl_down_sync(m, v, 2, kWidth);
  v += __shfl_down_sync(m, v, 1, kWidth);
  return v;
}
```
Contributor Author

Note: warp_sum and ln_block_threads_unified are the same for rms_norm_kernel and rms_norm_static_fp8_quant. Will move to shared helper after aligning on high level approach.
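For reference, a block-wide reduction built on a warp_sum like the one above typically stages per-warp partials through shared memory. A minimal sketch, assuming blockDim.x is a multiple of 32; `block_sum` and `kMaxWarps` are illustrative names, not code from this PR:

```cpp
// Illustrative only: composing warp_sum into a block-wide sum.
template <typename T>
__device__ __forceinline__ T block_sum(T v) {
  constexpr int kMaxWarps = 32;            // up to 1024 threads / 32 lanes
  __shared__ T warp_partials[kMaxWarps];

  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;

  v = warp_sum(v);                         // reduce within each warp
  if (lane == 0) warp_partials[warp] = v;  // stash one partial per warp
  __syncthreads();

  // The first warp reduces the per-warp partials; lane 0 ends up with the total.
  const int num_warps = (blockDim.x + 31) / 32;
  v = (warp == 0 && lane < num_warps) ? warp_partials[lane] : T(0);
  if (warp == 0) v = warp_sum(v);
  return v;
}
```

The block would then typically broadcast the lane-0 result through one more shared variable before computing inv_rms.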

Member

@mgoin commented Aug 12, 2025

cc @yewentao256

Collaborator

@yewentao256 left a comment

Thanks for the work!

Could we also test end-to-end performance & accuracy?

For accuracy: lm_eval ...

For E2E performance (optional, nice to have): vllm bench ...

Comment on lines +210 to +227
```cpp
  // -------- Pass 2: Vectorize when phases align --------
  constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B packets
  constexpr int WIDTH = V * sizeof(scalar_t);

  const bool can_vec = (hidden_size % V == 0) &&
                       same_phase(in_row, out_row, WIDTH) &&
                       same_phase(in_row, weight, WIDTH);

  if (can_vec) {
    const uintptr_t addr = reinterpret_cast<uintptr_t>(in_row);
    const int mis = addr & (WIDTH - 1);
    const int prefix = mis ? (WIDTH - mis) / (int)sizeof(scalar_t) : 0;

    // scalar prefix
    for (int i = threadIdx.x; i < prefix; i += blockDim.x) {
      float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
      out_row[i] = static_cast<scalar_t>(x * inv_rms) * weight[i];
    }
```
Collaborator

We have the duplicate logic in vllm/csrc/quantization/vectorization_utils.cuh, could we reuse the util?
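For illustration, a shared helper of that kind usually takes the element index and lets callers pass scalar and vector bodies. The sketch below shows the general shape only; `for_each_aligned` and its signature are hypothetical and not the actual interface in vectorization_utils.cuh:

```cpp
// Hypothetical shape of a reusable aligned-vectorization helper (illustrative only).
#include <cstdint>

template <int V, typename scalar_t, typename ScalarOp, typename VecOp>
__device__ void for_each_aligned(const scalar_t* __restrict__ in, int n,
                                 ScalarOp&& scalar_op, VecOp&& vec_op) {
  // Assumes kWidth is a power of two (e.g. 16 bytes).
  constexpr int kWidth = V * static_cast<int>(sizeof(scalar_t));
  const uintptr_t addr = reinterpret_cast<uintptr_t>(in);
  const int mis = static_cast<int>(addr & (kWidth - 1));
  const int prefix = mis ? (kWidth - mis) / static_cast<int>(sizeof(scalar_t)) : 0;

  // Scalar prefix until the pointer reaches packet alignment.
  const int prefix_end = prefix < n ? prefix : n;
  for (int i = threadIdx.x; i < prefix_end; i += blockDim.x) scalar_op(i);

  // Aligned vector body, V elements per iteration.
  const int body = (n > prefix) ? (n - prefix) / V : 0;
  for (int i = threadIdx.x; i < body; i += blockDim.x) vec_op(prefix + i * V);

  // Scalar tail for any remainder.
  for (int i = prefix + body * V + threadIdx.x; i < n; i += blockDim.x) scalar_op(i);
}
```

With a helper along these lines, the prefix/body/tail handling in Pass 2 above collapses into a single call that receives the scalar and vectorized bodies as lambdas.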

Comment on lines +259 to +263
```cpp
  // -------- Fallback scalar --------
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
    scalar_t xn = static_cast<scalar_t>(x * inv_rms);
    out_row[i] = xn * weight[i];
```
Collaborator

Same as above

Comment on lines 268 to +269
```diff
   int hidden_size = input.size(-1);
-  int input_stride = input.stride(-2);
+  int64_t input_stride = input.stride(-2);
```
Collaborator

Please also add a comment here explaining why we need to update this to int64.
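For example, something along these lines would capture the reasoning from the review above (the wording is only a suggestion):

```cpp
// Row stride must be 64-bit: for large tensors, token_idx * input_stride can
// exceed INT_MAX, and a 32-bit product would overflow before it is used as an
// offset into the input.
int64_t input_stride = input.stride(-2);
```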

Labels
performance Performance-related issues

4 participants