
Conversation

bbeckca
Contributor

@bbeckca bbeckca commented Aug 10, 2025

Summary:
• Vectorize writes when in/out phases align; scalar reads for weight when needed
• Use an optional fp16 shared-mem cache (avoids second read)
• Preserve numerics on odd/strided shapes by keeping scalar fallback
• Match unfused launch/reduction order (see the sketch below)
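
At a high level, the kernel follows the pattern sketched below. This is a simplified, self-contained illustration rather than the kernel in this diff: it covers only fp16, assumes 16-byte-aligned base pointers and a blockDim.x that is a multiple of 32, and names such as rms_norm_fp16_sketch and HalfVec8 are made up for the sketch.

```
#include <cuda_fp16.h>

// 32-lane warp reduction (CUDA-only for brevity).
__device__ __forceinline__ float warp_sum_f(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_down_sync(0xffffffffu, v, offset);
  return v;
}

// Eight halves packed for 16-byte vectorized loads/stores.
struct alignas(16) HalfVec8 { __half h[8]; };

// One block per token row. Dynamic shared memory must be
// hidden_size * sizeof(__half) when use_cached is true.
__global__ void rms_norm_fp16_sketch(__half* __restrict__ out,
                                     const __half* __restrict__ in,
                                     const __half* __restrict__ weight,
                                     int hidden_size, float eps,
                                     bool use_cached) {
  extern __shared__ __half s_in[];
  const __half* in_row = in + (int64_t)blockIdx.x * hidden_size;
  __half* out_row = out + (int64_t)blockIdx.x * hidden_size;

  // Pass 1: sum of squares; optionally cache the row in shared memory so
  // pass 2 avoids a second global read.
  float ss = 0.f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = __half2float(in_row[i]);
    if (use_cached) s_in[i] = in_row[i];
    ss += x * x;
  }

  // Block reduction: per-warp sums -> shared memory -> final warp reduction.
  __shared__ float warp_partials[32];
  __shared__ float s_inv_rms;
  ss = warp_sum_f(ss);
  if ((threadIdx.x & 31) == 0) warp_partials[threadIdx.x >> 5] = ss;
  __syncthreads();
  if (threadIdx.x < 32) {
    int num_warps = (int)(blockDim.x + 31) >> 5;
    float w = threadIdx.x < (unsigned)num_warps ? warp_partials[threadIdx.x] : 0.f;
    w = warp_sum_f(w);
    if (threadIdx.x == 0) s_inv_rms = rsqrtf(w / hidden_size + eps);
  }
  __syncthreads();
  float inv_rms = s_inv_rms;

  // Pass 2: vectorized writes when the row is a multiple of 8 halves
  // (16 bytes); otherwise the scalar fallback preserves the original
  // per-element ordering on odd/strided shapes.
  if (hidden_size % 8 == 0) {
    const HalfVec8* src =
        reinterpret_cast<const HalfVec8*>(use_cached ? s_in : in_row);
    const HalfVec8* w8 = reinterpret_cast<const HalfVec8*>(weight);
    HalfVec8* dst = reinterpret_cast<HalfVec8*>(out_row);
    for (int i = threadIdx.x; i < hidden_size / 8; i += blockDim.x) {
      HalfVec8 x = src[i], wv = w8[i], y;
      for (int j = 0; j < 8; ++j)
        y.h[j] = __hmul(__float2half(__half2float(x.h[j]) * inv_rms), wv.h[j]);
      dst[i] = y;
    }
  } else {
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
      float x = __half2float(use_cached ? s_in[i] : in_row[i]);
      out_row[i] = __hmul(__float2half(x * inv_rms), weight[i]);
    }
  }
}
```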

Test Plan:

  1. Run tests
(vllm-env) (base) bbeckca@instance-20250803-153508:~/vllm$ pytest tests/kernels/core/test_layernorm.py
================================ test session starts =================================
platform linux -- Python 3.10.15, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/bbeckca/vllm
configfile: pyproject.toml
plugins: anyio-4.10.0
collected 1584 items                                                                 

tests/kernels/core/test_layernorm.py ......................................... [  2%]
.............................................................................. [  7%]
.............................................................................. [ 12%]
.............................................................................. [ 17%]
.............................................................................. [ 22%]
.............................................................................. [ 27%]
.............................................................................. [ 32%]
.............................................................................. [ 37%]
.............................................................................. [ 41%]
.............................................................................. [ 46%]
.............................................................................. [ 51%]
.............................................................................. [ 56%]
.............................................................................. [ 61%]
.............................................................................. [ 66%]
.............................................................................. [ 71%]
.............................................................................. [ 76%]
.............................................................................. [ 81%]
.............................................................................. [ 86%]
.............................................................................. [ 91%]
.............................................................................. [ 96%]
.............................................................                  [100%]

========================== 1584 passed in 643.00s (0:10:43) ==========================
(vllm-env) (base) bbeckca@instance-20250803-153508:~/vllm$ python -m pytest tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine -v
==================================== test session starts ====================================
platform linux -- Python 3.10.15, pytest-8.4.1, pluggy-1.6.0 -- /home/bbeckca/vllm-env/bin/python
cachedir: .pytest_cache
rootdir: /home/bbeckca/vllm
configfile: pyproject.toml
plugins: anyio-4.10.0
collected 2 items                                                                           

tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine[Qwen/Qwen3-1.7B] PASSED [ 50%]
tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine[google/gemma-3-1b-it] PASSED [100%]
  2. Run benchmark
buck run :benchmark_layernorm -- --num-tokens 16384 --hidden-size 1024 --dtype half --num-iters 500

Performance (selected, non-strided / HND)

| T | H | dtype | baseline (µs) | current (µs) | Δ |
|---|---|-------|---------------|--------------|---|
| 4096 | 1024 | fp16 | 24.623 | 16.958 | -31.1% |
| 16384 | 1024 | fp16 | 106.138 | 41.896 | -60.5% |
| 4096 | 8192 | fp16 | 118.410 | 95.885 | -19.0% |
| 16384 | 8192 | fp16 | 449.317 | 355.356 | -20.9% |
| 4096 | 1024 | bf16 | 24.689 | 17.623 | -28.6% |
| 16384 | 1024 | bf16 | 106.355 | 55.205 | -48.1% |
| 4096 | 8192 | bf16 | 119.420 | 96.174 | -19.5% |
| 16384 | 8192 | bf16 | 449.268 | 354.678 | -21.1% |

Strided (NHD) overhead (current kernel) — penalty vs. HND (same T/H/dtype):

  • 4096×1024 fp16: 1.40× (23.700 / 16.958)
  • 16384×1024 fp16: 2.13× (89.433 / 41.896)
  • 4096×8192 fp16: 1.94× (185.850 / 95.885)

Differential Revision: D79969610


@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D79969610

@mergify mergify bot added the performance Performance-related issues label Aug 10, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request significantly refactors the RMSNorm CUDA kernels to improve performance through vectorization, shared memory caching, and unified reduction logic. The benchmark results show a substantial speedup. The changes are well-structured and the optimizations follow common CUDA best practices. However, I found a critical integer overflow bug in the quantized fused kernel. While the input stride was correctly updated to int64_t to support large tensors, the corresponding index variables were left as int, which could lead to incorrect memory access and corrupted results.
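
For readers less familiar with this failure mode, the pattern being flagged looks roughly like this (an illustrative snippet, not the actual code in the diff):

```
// Illustrative only: the stride is widened to int64_t, but keeping the row
// offset in a 32-bit int truncates it once blockIdx.x * input_stride
// exceeds INT_MAX, so the kernel reads/writes the wrong row.
__global__ void index_overflow_example(const float* in, int64_t input_stride) {
  int bad_offset = blockIdx.x * input_stride;       // truncated to 32 bits
  int64_t good_offset = blockIdx.x * input_stride;  // keeps the full offset
  const float* bad_row = in + bad_offset;           // may point to the wrong row
  const float* good_row = in + good_offset;
  (void)bad_row;
  (void)good_row;
}
```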

Comment on lines 17 to 32
template <typename T>
__device__ __forceinline__ T warp_sum(T v) {
#ifdef __HIP_PLATFORM_AMD__
  const unsigned long long m = 0xffffffffffffffffull;  // HIP needs 64-bit mask
#else
  const unsigned m = 0xffffffffu;  // CUDA 32-bit mask
#endif
  // Always reduce over 32 lanes to match downstream logic.
  constexpr int kWidth = 32;
  v += __shfl_down_sync(m, v, 16, kWidth);
  v += __shfl_down_sync(m, v, 8, kWidth);
  v += __shfl_down_sync(m, v, 4, kWidth);
  v += __shfl_down_sync(m, v, 2, kWidth);
  v += __shfl_down_sync(m, v, 1, kWidth);
  return v;
}
Contributor Author

Note: warp_sum and ln_block_threads_unified are the same for rms_norm_kernel and rms_norm_static_fp8_quant. Will move them to a shared helper after aligning on the high-level approach.
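
For context, the block-level piece that sits on top of a 32-lane warp_sum is usually shaped like this (a generic sketch built on the warp_sum quoted above; block_sum is not a name from this PR, and it assumes blockDim.x is a multiple of 32 and at most 1024):

```
// Generic block-wide sum composed from the 32-lane warp_sum helper above.
template <typename T>
__device__ __forceinline__ T block_sum(T v) {
  __shared__ T warp_partials[32];
  const int lane = threadIdx.x & 31;
  const int warp = threadIdx.x >> 5;

  v = warp_sum(v);                         // 1) reduce within each warp
  if (lane == 0) warp_partials[warp] = v;  // 2) one partial per warp
  __syncthreads();

  const int num_warps = blockDim.x >> 5;
  v = (threadIdx.x < (unsigned)num_warps) ? warp_partials[threadIdx.x] : T(0);
  if (warp == 0) v = warp_sum(v);          // 3) warp 0 reduces the partials
  return v;                                // result is valid in thread 0
}
```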

@mgoin
Member

mgoin commented Aug 12, 2025

cc @yewentao256

Collaborator

@yewentao256 yewentao256 left a comment

Thanks for the work!

Could we also test end-to-end performance and accuracy?

For accuracy: lm_eval ...

For E2E performance (optional, nice to have): vllm bench ...

Comment on lines 259 to 236
// -------- Fallback scalar --------
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
  float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
  scalar_t xn = static_cast<scalar_t>(x * inv_rms);
  out_row[i] = xn * weight[i];
Collaborator

Same as above.

Contributor Author

Kept the scalar fallback for now. The vectorization util (even with V=1) peels a prefix and changes the per-thread write order, which causes small FP8 mismatches vs. the fused path on unaligned shapes. Happy to revisit once the fused ordering matches, or to revise further if you have any preferences.
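
To make the ordering difference concrete: with a plain strided loop, thread t touches elements t, t + blockDim.x, and so on, while a prefix-peeling vectorized loop hands each thread contiguous V-element chunks after the peeled prefix. The sketch below is illustrative only (process, kV, and the function names are placeholders, not the util's API):

```
// Plain scalar loop: thread t handles t, t + blockDim.x, t + 2*blockDim.x, ...
template <typename Fn>
__device__ void plain_scalar_loop(int hidden_size, Fn process) {
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) process(i);
}

// Prefix-peeling vectorized loop: peel the unaligned prefix, then hand each
// thread contiguous kV-element chunks, then a scalar tail. Once prefix > 0,
// the element -> thread mapping no longer matches the plain loop above,
// even for kV == 1.
template <int kV, typename Fn>
__device__ void prefix_peeled_loop(int hidden_size, int prefix, Fn process) {
  for (int i = threadIdx.x; i < prefix; i += blockDim.x) process(i);
  const int body = ((hidden_size - prefix) / kV) * kV;
  for (int c = threadIdx.x; c * kV < body; c += blockDim.x)
    for (int j = 0; j < kV; ++j) process(prefix + c * kV + j);
  for (int i = prefix + body + threadIdx.x; i < hidden_size; i += blockDim.x)
    process(i);
}
```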

Collaborator

The scalar fallback inside the util is actually there to handle address misalignment, covering the unaligned prefix and the tail. Could you give me an example of the mismatches? Then we can look together at whether we should update the util itself.

@bbeckca
Contributor Author

bbeckca commented Aug 13, 2025

Thanks for the feedback @yewentao256. Will work on addressing and testing e2e.

bbeckca added a commit to bbeckca/vllm that referenced this pull request Aug 16, 2025
Summary:

What: Make RMSNorm faster by reading data in bigger aligned chunks, caching fp16 rows in shared memory, making the FP8-quant version work with strided inputs, and using the same launch settings as the unfused path. 

Why: Cut global memory traffic using aligned vector inputs/outputs and shared-mem reuse (avoids second read), make the FP8 path safe for strided inputs, and preserve numerics by matching the unfused reduction/launch order.

Test Plan:
1) Run tests

```
[[email protected] /data/users/benjibeck/fbsource/fbcode/vllm (1043a27694)]$ buck2 test :test_kernels_layernorm
Buck UI: https://www.internalfb.com/buck2/054ebad3-ad92-4676-a4d2-3bf43e44f31a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10414574240710255
Network: Up: 152MiB  Down: 2.9GiB  (reSessionID-14af330c-26bf-41d5-87b0-5775bf7d6f8a)
Loading targets.   Remaining     0/7                                                                                           150 dirs read, 69 targets declared
Analyzing targets. Remaining     0/32                                                                                          772 actions, 819 artifacts declared
Executing actions. Remaining     0/245                                                                                         48.3s exec time total
Command: test.     Finished 1 local, 14 remote, 131 cache (90% hit)                                                            45.2s exec time cached (93%)                                         
Time elapsed: 4:53.4s
Tests finished: Pass 3169. Fail 0. Fatal 0. Skip 0. Build failure 0
```

2) Run benchmark
```
buck run :benchmark_layernorm -- --num-tokens 16384 --hidden-size 1024 --dtype half --num-iters 500
```

Performance (selected, non-strided / HND)

| T | H | dtype | baseline (µs) | current (µs) | Δ |
|---|---|-------|---------------|--------------|---|
| 4096  | 1024 | fp16 | 24.592  | 16.552  | -32.7% |
| 16384 | 1024 | fp16 | 106.699 | 42.739  | -60.0% |
| 4096  | 8192 | fp16 | 118.566 | 97.059  | -18.1% |
| 16384 | 8192 | fp16 | 450.738 | 356.125 | -21.0% |
| 4096  | 1024 | bf16 | 24.743  | 16.683  | -32.6% |
| 16384 | 1024 | bf16 | 107.009 | 56.946  | -46.8% |
| 4096  | 8192 | bf16 | 119.293 | 96.774  | -18.9% |
| 16384 | 8192 | bf16 | 451.181 | 357.799 | -20.7% |

Strided (NHD) overhead (current kernel) — penalty vs. HND (same T/H/dtype):
- 4096×1024 fp16: 1.39× (22.983 / 16.552)
- 16384×1024 fp16: 2.13× (90.995 / 42.739)
- 4096×8192 fp16: 1.93× (186.931 / 97.059)

Rollback Plan:

Differential Revision: D79969610
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D79969610


@bbeckca
Contributor Author

bbeckca commented Aug 17, 2025

Update: Performed an e2e latency benchmark on 2x H100 and observed a ~4.8% reduction with the vectorized rms_norm_kernel without shared memory. I'm planning to benchmark further with other models/configurations to understand this better. Let me know if you folks have any thoughts/preferences.

VLLM_USE_V1=1 buck run //vllm:benchmark_latency --   --model facebook/opt-125m   --input-len 32 --output-len 128   --dtype float16 --batch-size 128 --enforce-eager

baseline
Avg latency: 0.6731001150328666 seconds
10% percentile latency: 0.6315004411269911 seconds
25% percentile latency: 0.6374467767454917 seconds
50% percentile latency: 0.6609808525245171 seconds
75% percentile latency: 0.6663343247637386 seconds
90% percentile latency: 0.6783564629906322 seconds
99% percentile latency: 1.0606678462994754 seconds

vectorized
Avg latency: 0.6741364902312247 seconds
10% percentile latency: 0.6488532381190453 seconds
25% percentile latency: 0.6496474424784537 seconds
50% percentile latency: 0.6550397415121552 seconds
75% percentile latency: 0.6627388229971984 seconds
90% percentile latency: 0.666680197842652 seconds
99% percentile latency: 1.039106485189987 seconds

vectorized (without shared memory)
Avg latency: 0.6409522683670124 seconds
10% percentile latency: 0.6146808989578858 seconds
25% percentile latency: 0.6191042929858668 seconds
50% percentile latency: 0.6249049424950499 seconds
75% percentile latency: 0.6411205744952895 seconds
90% percentile latency: 0.6505832085269503 seconds
99% percentile latency: 0.8902701860625533 seconds

@bbeckca
Contributor Author

bbeckca commented Aug 17, 2025

Ran quick E2E latency on a few models (bs=128, in=32, out=128). Overall: a small win on Llama-3.2, neutral on Qwen fp16, and Gemma-2-9B does better if we skip the shared-mem cached row. I can rerun with other settings or add more models if you have any preferences.

@yewentao256 Wondering if we should consider adding a knob to disable the cached-row path where it regresses for some models?

| Model | dtype | baseline (s) | vectorized (s) | Δ vs base | no-shared-mem (s) | Δ vs base |
|-------|-------|--------------|----------------|-----------|-------------------|-----------|
| Llama-3.2-3B-Instruct | fp16 | 1.2895 | 1.2698 | −1.5% | 1.2863 | −0.3% |
| Llama-3.2-3B-Instruct | bf16 | 1.2772 | 1.2663 | −0.9% | 1.2892 | +0.9% |
| Qwen2-7B | fp16 | 1.7673 | 1.7692 | +0.1% | 1.7679 | +0.0% |
| Qwen2-7B | bf16 | 1.7307 | 1.7143 | −0.9% | 1.6929 | −2.2% |
| Gemma-2-9B | bf16 | 4.1187 | 4.2222 | +2.5% | 4.0485 | −1.7% |
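
If we did add such a knob, one lightweight shape it could take is a boolean template parameter on the kernel plus a host-side switch. The snippet below is only a sketch of the idea; kUseSmemCache, the env var name, and launch_rms_norm are hypothetical and not part of this PR:

```
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical: the kernel body would gate the cached-row path on the
// template flag, e.g. via `if constexpr (kUseSmemCache) { ... }`.
template <typename scalar_t, bool kUseSmemCache>
__global__ void rms_norm_kernel_sketch(scalar_t* out, const scalar_t* in,
                                       const scalar_t* weight,
                                       int hidden_size, float eps) {
  // Body elided in this sketch.
}

template <typename scalar_t>
void launch_rms_norm(scalar_t* out, const scalar_t* in, const scalar_t* weight,
                     int num_tokens, int hidden_size, float eps,
                     cudaStream_t stream) {
  // Example policy: an env-var override, defaulting to the cached-row path.
  static const bool use_cache = std::getenv("VLLM_RMS_NORM_NO_SMEM") == nullptr;
  const dim3 grid(num_tokens), block(256);
  if (use_cache) {
    const size_t smem = hidden_size * sizeof(scalar_t);
    rms_norm_kernel_sketch<scalar_t, true>
        <<<grid, block, smem, stream>>>(out, in, weight, hidden_size, eps);
  } else {
    rms_norm_kernel_sketch<scalar_t, false>
        <<<grid, block, 0, stream>>>(out, in, weight, hidden_size, eps);
  }
}
```

A heuristic on hidden_size (or a per-model config field) could replace the env var; the point is just that both specializations stay compiled and the choice is made at launch time.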

Contributor Author

Note: These changes were made for testing within Meta infra and will be deleted before landing.

@bbeckca
Contributor Author

bbeckca commented Aug 20, 2025


cc @mgoin @WoosukKwon for any additional thoughts.

@bbeckca
Contributor Author

bbeckca commented Aug 23, 2025


Hi @yewentao256, just following up in case you have time to review. Would be great if we could align on these results before making further changes. Thanks!

Collaborator

@yewentao256 yewentao256 left a comment

Sorry for the late response; there were some problems with my GitHub notifications, but they are fixed now.
The overall work looks good. Could you

  1. test the accuracy using lm_eval?
  2. benchmark the kernel directly?
    One example can be seen here: #22036
