Update rms_norm_kernel by removing redundant global memory loads #22134

Open
wants to merge 2 commits into main

Conversation

bbeckca
Contributor

@bbeckca bbeckca commented Aug 3, 2025

Purpose

This PR updates rms_norm_kernel by removing a redundant global memory load: the input row is cached in shared memory during the variance-computation pass and reused during normalization. In theory this reduces global memory bandwidth usage; in practice the runtime benchmarks below show neutral impact.
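
As a rough sketch of the idea (not the actual vLLM kernel: float-only types, assumed names, and a naive atomicAdd reduction stand in for the real block-wide reduction):

// Simplified sketch of the two-pass structure with the cached read.
// One thread block handles one token (row); launch as <<<num_tokens, block_size>>>.
__global__ void rms_norm_cached(float* __restrict__ out,
                                const float* __restrict__ input,
                                const float* __restrict__ weight,
                                const float epsilon, const int hidden_size) {
  __shared__ float s_input[1024];  // assumes hidden_size <= 1024 (see review below)
  __shared__ float s_variance;

  if (threadIdx.x == 0) s_variance = 0.0f;
  __syncthreads();

  // Pass 1: accumulate the sum of squares and cache each loaded value.
  float local = 0.0f;
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = input[blockIdx.x * hidden_size + idx];
    s_input[idx] = x;  // cached so pass 2 does not re-read global memory
    local += x * x;
  }
  atomicAdd(&s_variance, local);
  __syncthreads();

  const float inv_rms = rsqrtf(s_variance / hidden_size + epsilon);

  // Pass 2: reuse the cached value; previously this loop re-loaded input[].
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    out[blockIdx.x * hidden_size + idx] = s_input[idx] * inv_rms * weight[idx];
  }
}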

Test Plan

  • Run CI to validate correctness
  • Manually run benchmark script

Test Result

(vllm-env) (base) bbeckca@instance-20250803-153508:~/vllm$ git branch
  layernorm-memory
* main
(vllm-env) (base) bbeckca@instance-20250803-153508:~/vllm$ python benchmarks/kernels/benchmark_layernorm.py \
  --num-tokens 16384 \
  --hidden-size 1024 \
  --dtype half \
  --num-iters 500
INFO 08-03 20:41:30 [__init__.py:235] Automatically detected platform cuda.
Namespace(num_tokens=16384, hidden_size=1024, add_residual=False, dtype='half', seed=0, profile=False, num_warmup_iters=5, num_iters=500)
WARNING 08-03 20:41:30 [config.py:4898] Current vLLM config is not set.
WARNING 08-03 20:41:30 [config.py:4898] Current vLLM config is not set.
WARNING 08-03 20:41:30 [config.py:4898] Current vLLM config is not set.
Warming up...
Kernel running time: 394.061 us
(vllm-env) (base) bbeckca@instance-20250803-153508:~/vllm$ git checkout layernorm-memory 
Switched to branch 'layernorm-memory'
Your branch is up to date with 'origin/layernorm-memory'.
(vllm-env) (base) bbeckca@instance-20250803-153508:~/vllm$ python benchmarks/kernels/benchmark_layernorm.py   --num-tokens 16384   --hidden-size 1024   --dtype half   --num-iters 500
INFO 08-03 20:42:14 [__init__.py:235] Automatically detected platform cuda.
Namespace(num_tokens=16384, hidden_size=1024, add_residual=False, dtype='half', seed=0, profile=False, num_warmup_iters=5, num_iters=500)
WARNING 08-03 20:42:14 [config.py:4898] Current vLLM config is not set.
WARNING 08-03 20:42:14 [config.py:4898] Current vLLM config is not set.
WARNING 08-03 20:42:14 [config.py:4898] Current vLLM config is not set.
Warming up...
Kernel running time: 394.943 us
(vllm-env) (base) bbeckca@instance-20250803-153508:~/vllm$ 

github-actions bot commented Aug 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request aims to optimize the rms_norm_kernel by caching the input tensor in shared memory to avoid a second read from global memory. While the optimization strategy is sound, the implementation introduces a critical bug. A fixed-size shared memory array of 1024 elements is used, but the hidden_size can be larger. This leads to out-of-bounds memory access, which can cause kernel crashes or silent data corruption. My review provides details on the issue and suggests using dynamic shared memory to correctly handle variable hidden_size.

@@ -21,10 +20,14 @@ __global__ void rms_norm_kernel(
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
+  __shared__ float s_input[1024];

critical

The shared memory array s_input is allocated with a fixed size of 1024 elements. However, the hidden_size of the model can be larger than this, as indicated by the test cases (e.g., 5120, 8192).

In the loop starting on line 28, the index idx can go up to hidden_size - 1. If hidden_size > 1024, accessing s_input[idx] will result in an out-of-bounds write to shared memory. This is undefined behavior and will likely lead to race conditions, incorrect outputs, or kernel crashes.

To support an arbitrary hidden_size, you should use dynamic shared memory. This would involve:

  1. Declaring shared memory with extern __shared__ float s_mem[]; in the kernel.
  2. Calculating the required shared memory size in the host function rms_norm based on hidden_size and other shared memory requirements (s_variance, reduceStore).
  3. Passing this size as the third parameter in the kernel launch configuration <<<...>>>.
  4. Manually managing pointers into the s_mem buffer for s_input, s_variance, and reduceStore inside the kernel.

Since the host-side changes are outside of this diff, I cannot provide a complete code suggestion. However, the current implementation is critically flawed for any hidden_size > 1024.
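
For illustration, the pattern described above might look roughly like this (a sketch with assumed names; a naive atomicAdd reduction replaces the kernel's block-wide reduction, so the reduceStore buffer mentioned above is omitted):

// Kernel side: one extern buffer, sized at launch time and partitioned
// manually into the per-row cache and the reduction slot.
__global__ void rms_norm_cached_dyn(float* __restrict__ out,
                                    const float* __restrict__ input,
                                    const float* __restrict__ weight,
                                    const float epsilon, const int hidden_size) {
  extern __shared__ float s_mem[];
  float* s_input = s_mem;                   // hidden_size floats
  float* s_variance = s_mem + hidden_size;  // 1 float for the reduced value

  if (threadIdx.x == 0) *s_variance = 0.0f;
  __syncthreads();

  float local = 0.0f;
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = input[blockIdx.x * hidden_size + idx];
    s_input[idx] = x;
    local += x * x;
  }
  atomicAdd(s_variance, local);
  __syncthreads();

  const float inv_rms = rsqrtf(*s_variance / hidden_size + epsilon);
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    out[blockIdx.x * hidden_size + idx] = s_input[idx] * inv_rms * weight[idx];
  }
}

// Host side: the third launch parameter supplies the dynamic shared memory size.
// size_t smem_bytes = (hidden_size + 1) * sizeof(float);
// rms_norm_cached_dyn<<<num_tokens, block_size, smem_bytes>>>(
//     out, input, weight, epsilon, hidden_size);

Note that caching the whole row also runs into the per-block shared memory limit (48 KB by default) once hidden_size grows large enough, which is a separate constraint from the fixed 1024-element array.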

@bbeckca
Contributor Author

bbeckca commented Aug 3, 2025

I'll take a closer look at the failing tests; meanwhile, feel free to provide any feedback on the overall approach.

@DarkLight1337
Member

Same here, please provide some benchmarks.

cc @WoosukKwon

@bbeckca changed the title from "Optimize rms_norm_kernel by removing redundant global memory loads" to "Update rms_norm_kernel by removing redundant global memory loads" on Aug 3, 2025
@bbeckca
Contributor Author

bbeckca commented Aug 3, 2025

Same here, please provide some benchmarks.

cc @WoosukKwon

Just added benchmark results to the test plan: ran 16384 x 1024 with 500 iters, and perf was neutral (~394 µs on both branches). No speedup, but it removes a redundant read. Let me know if you have a different config in mind to profile.

@DarkLight1337
Member

Also cc @mgoin

@mgoin
Member

mgoin commented Aug 5, 2025

Hi @bbeckca thanks for the idea. Would you want to try applying the vectorization util to this op? See this PR for most recent usage #22036

@bbeckca
Contributor Author

bbeckca commented Aug 5, 2025

Hi @bbeckca thanks for the idea. Would you want to try applying the vectorization util to this op? See this PR for most recent usage #22036

Appreciate the suggestion! I'll ramp up on vectorization over the next couple of days and follow up once I start implementing changes.

@bbeckca
Contributor Author

bbeckca commented Aug 10, 2025

Hi all, just following up. I’ve posted a prototype (#22602) that vectorizes RMSNorm and makes the FP8-quant path stride-safe while matching the unfused kernel’s reduction/launch order.

All tests are passing in test_layernorm.py. Benchmark on (16384×1024, fp16, 500 iters): ~106 µs → ~42.6 µs per call.
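
For context, a generic sketch of what vectorized loads can look like in the variance pass (this is not the vLLM vectorization util referenced in #22036; it assumes an even hidden_size, rows aligned for __half2 access, and a zero-initialized sum_sq buffer):

#include <cuda_fp16.h>

// Read the fp16 row as __half2 pairs: one 32-bit load yields two elements,
// halving the number of global load instructions in the reduction pass.
__global__ void row_sum_of_squares_half2(const __half* __restrict__ input,
                                         float* __restrict__ sum_sq,  // [num_tokens], zeroed
                                         const int hidden_size) {
  const __half2* row =
      reinterpret_cast<const __half2*>(input + blockIdx.x * hidden_size);
  float local = 0.0f;
  for (int i = threadIdx.x; i < hidden_size / 2; i += blockDim.x) {
    const float2 x = __half22float2(row[i]);  // convert both halves to float
    local += x.x * x.x + x.y * x.y;
  }
  atomicAdd(&sum_sq[blockIdx.x], local);      // naive reduction, for brevity
}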

I’ll add more benchmarks across sizes/dtypes/strides. All feedback is welcome!

cc @DarkLight1337 @mgoin @WoosukKwon
