Update rms_norm_kernel by removing redundant global memory loads #22134
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Code Review
This pull request aims to optimize the `rms_norm_kernel` by caching the input tensor in shared memory to avoid a second read from global memory. While the optimization strategy is sound, the implementation introduces a critical bug: a fixed-size shared memory array of 1024 elements is used, but the `hidden_size` can be larger. This leads to out-of-bounds memory access, which can cause kernel crashes or silent data corruption. My review provides details on the issue and suggests using dynamic shared memory to correctly handle a variable `hidden_size`.
csrc/layernorm_kernels.cu
Outdated
@@ -21,10 +20,14 @@ __global__ void rms_norm_kernel(
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  __shared__ float s_input[1024];
The shared memory array `s_input` is allocated with a fixed size of 1024 elements. However, the `hidden_size` of the model can be larger than this, as indicated by the test cases (e.g., 5120, 8192).

In the loop starting on line 28, the index `idx` can go up to `hidden_size - 1`. If `hidden_size > 1024`, accessing `s_input[idx]` will result in an out-of-bounds write to shared memory. This is undefined behavior and will likely lead to race conditions, incorrect outputs, or kernel crashes.
To support an arbitrary `hidden_size`, you should use dynamic shared memory. This would involve:

- Declaring shared memory with `extern __shared__ float s_mem[];` in the kernel.
- Calculating the required shared memory size in the host function `rms_norm`, based on `hidden_size` and the other shared memory requirements (`s_variance`, `reduceStore`).
- Passing this size as the third parameter in the kernel launch configuration `<<<...>>>`.
- Manually managing pointers into the `s_mem` buffer for `s_input`, `s_variance`, and `reduceStore` inside the kernel.

Since the host-side changes are outside of this diff, I cannot provide a complete code suggestion. However, the current implementation is critically flawed for any `hidden_size > 1024`.
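For illustration, here is a minimal sketch of the dynamic-shared-memory pattern described in this review. The kernel and launcher names (`rms_norm_dyn_smem_kernel`, `launch_rms_norm`), the float-only signature, and the `atomicAdd`-based reduction are assumptions made to keep the sketch short; they are not vLLM's actual templated `rms_norm` implementation or its block reduction (the `reduceStore` mentioned above).

```cuda
#include <algorithm>
#include <cuda_runtime.h>

// Sketch: RMSNorm with a dynamically sized shared-memory cache of the input
// row, so an arbitrary hidden_size is supported. One block handles one token.
__global__ void rms_norm_dyn_smem_kernel(
    float* __restrict__ out, const float* __restrict__ input,
    const float* __restrict__ weight, const float epsilon,
    const int hidden_size) {
  extern __shared__ float s_mem[];          // sized at launch time
  float* s_input = s_mem;                   // hidden_size floats
  float* s_variance = s_mem + hidden_size;  // 1 float

  const int64_t token = blockIdx.x;
  float sq_sum = 0.0f;

  // Pass 1: cache the row in shared memory while accumulating sum of squares.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = input[token * hidden_size + idx];
    s_input[idx] = x;
    sq_sum += x * x;
  }

  if (threadIdx.x == 0) *s_variance = 0.0f;
  __syncthreads();
  atomicAdd(s_variance, sq_sum);  // stand-in for a proper block reduction
  __syncthreads();
  const float inv_rms = rsqrtf(*s_variance / hidden_size + epsilon);

  // Pass 2: normalize from the shared-memory copy instead of re-reading the
  // input row from global memory.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    out[token * hidden_size + idx] = s_input[idx] * inv_rms * weight[idx];
  }
}

// Host side: compute the dynamic shared-memory size and pass it as the third
// parameter of the launch configuration.
void launch_rms_norm(float* out, const float* input, const float* weight,
                     float epsilon, int num_tokens, int hidden_size,
                     cudaStream_t stream) {
  const dim3 grid(num_tokens);
  const dim3 block(std::min(hidden_size, 1024));
  const size_t smem_bytes = (hidden_size + 1) * sizeof(float);
  rms_norm_dyn_smem_kernel<<<grid, block, smem_bytes, stream>>>(
      out, input, weight, epsilon, hidden_size);
}
```

One caveat for very large `hidden_size`: once the requested dynamic shared memory exceeds the default per-block limit, the launcher would also need to opt in via `cudaFuncSetAttribute` with `cudaFuncAttributeMaxDynamicSharedMemorySize`, or fall back to the original two-load path.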
Force-pushed from 922403e to 3b78127
Signed-off-by: Benji Beck <[email protected]>
Force-pushed from 3b78127 to 303987a
I'll take a closer look at the failing tests; in the meantime, feel free to provide any feedback on the overall approach.
Same here; please provide some benchmarks. cc @WoosukKwon
Signed-off-by: Benji Beck <[email protected]>
Force-pushed from e04ac2f to 5360a4d
Just added benchmark results to the test plan: ran 16384 x 1024 for 500 iterations, and performance was neutral (~394 µs on both branches). No speedup, but it removes a redundant read. Let me know if you have a different config in mind to profile.
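For reference, a generic CUDA-event timing loop in the spirit of that measurement (16384 x 1024, 500 iterations). It reuses the hypothetical `launch_rms_norm` sketch from the review thread above and is not the benchmark script behind the numbers quoted here.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Assumes launch_rms_norm(...) from the earlier sketch is visible here.
int main() {
  const int num_tokens = 16384, hidden_size = 1024, iters = 500;
  const size_t n = (size_t)num_tokens * hidden_size;

  // Device buffers are left uninitialized; that is fine for a pure timing run.
  float *d_in, *d_out, *d_w;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMalloc(&d_w, hidden_size * sizeof(float));

  // Warm-up launch so the timed loop excludes one-time costs.
  launch_rms_norm(d_out, d_in, d_w, 1e-6f, num_tokens, hidden_size, 0);

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) {
    launch_rms_norm(d_out, d_in, d_w, 1e-6f, num_tokens, hidden_size, 0);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  printf("avg per call: %.1f us\n", 1000.0f * ms / iters);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_w);
  return 0;
}
```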
Also cc @mgoin
Hi all, just following up. I've posted a prototype (#22602) that vectorizes RMSNorm and makes the FP8-quant path stride-safe while matching the unfused kernel's reduction/launch order. All tests are passing in test_layernorm.py. Benchmark (16384×1024, fp16, 500 iters): ~106 µs → ~42.6 µs per call. I'll add more benchmarks across sizes/dtypes/strides. All feedback is welcome!
Purpose
This PR updates `rms_norm_kernel` by removing a redundant global memory load. It caches the input into shared memory during the variance-computation pass and reuses it during normalization. While this should reduce memory bandwidth usage and improve efficiency in theory, runtime benchmarks show a neutral impact.
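To make the change concrete, below is a hedged sketch of the baseline access pattern this PR targets, with the redundant second global read annotated. The kernel name, the float-only signature, and the `atomicAdd`-based reduction are simplifications for illustration, not the actual vLLM kernel.

```cuda
#include <cuda_runtime.h>

// Baseline pattern: the input row is read from global memory twice, once for
// the variance reduction and once for normalization. One block per token.
__global__ void rms_norm_two_load_kernel(
    float* __restrict__ out, const float* __restrict__ input,
    const float* __restrict__ weight, const float epsilon,
    const int hidden_size) {
  __shared__ float s_sum;
  const int64_t token = blockIdx.x;

  float sq_sum = 0.0f;
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = input[token * hidden_size + idx];  // global load #1
    sq_sum += x * x;
  }
  if (threadIdx.x == 0) s_sum = 0.0f;
  __syncthreads();
  atomicAdd(&s_sum, sq_sum);  // stand-in for a proper block reduction
  __syncthreads();
  const float inv_rms = rsqrtf(s_sum / hidden_size + epsilon);

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = input[token * hidden_size + idx];  // global load #2 (redundant)
    out[token * hidden_size + idx] = x * inv_rms * weight[idx];
  }
}
```

This PR's variant replaces global load #2 with a read from the shared-memory copy written during the first pass, which is also why the review above insists that the cache be sized to `hidden_size` rather than fixed at 1024 elements.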
Test Plan
Test Result