Fix Flashinfer Allreduce+Norm enable disable calculation based on fi_allreduce_fusion_max_token_num #21325

Merged

3 commits merged into vllm-project:main from xinli-git:xinli/fix-fi-allreduce on Jul 22, 2025

Conversation

@xinli-git (Contributor) commented Jul 21, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

Purpose

The cudagraph warm-up passes `all_reduce_in` with shape `[num_tokens, hidden_size]` (for DeepSeek-R1, `hidden_size` is 7168), but the enable/disable comparison uses `num_tokens * shape[0]`, where `shape[0]` is also `num_tokens`, rather than `num_tokens * hidden_size`. With a large enough hidden size this causes the FlashInfer fusion to always be disabled.
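For illustration, here is a minimal sketch of the intended size check. The standalone helper and its name are hypothetical (vLLM performs this check inside its compilation pass), but it shows the difference between the buggy and corrected calculation:

```python
import torch


def should_enable_fi_allreduce_fusion(
    allreduce_in: torch.Tensor,
    fi_allreduce_fusion_max_token_num: int,
) -> bool:
    """Hypothetical helper: decide whether the FlashInfer allreduce+norm
    fusion should be enabled for this input.

    allreduce_in has shape [num_tokens, hidden_size].
    """
    num_tokens, hidden_size = allreduce_in.shape

    # Buggy comparison described above: shape[0] is num_tokens, so the
    # limit becomes num_tokens * num_tokens and has nothing to do with
    # the actual tensor size.
    # max_elements = fi_allreduce_fusion_max_token_num * allreduce_in.shape[0]

    # Corrected idea: scale the token limit by hidden_size so it can be
    # compared against the real element count of the input.
    max_elements = fi_allreduce_fusion_max_token_num * hidden_size
    return allreduce_in.numel() <= max_elements
```

Under this form of the check, the warm-up shape `[num_tokens, 7168]` enables the fusion whenever `num_tokens` fits under `fi_allreduce_fusion_max_token_num`, instead of being rejected simply because the hidden size is large.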

Test Result

On top-of-tree, no FlashInfer kernels are called:

VLLM_TORCH_PROFILER_DIR=./torch_profiler VLLM_DISABLE_COMPILE_CACHE=1 python3 vllm/benchmarks/benchmark_latency.py --model deepseek-ai/DeepSeek-R1  --trust-remote --load_format dummy  --gpu_memory_utilization=0.90 --max-num-seqs=1024  --tensor-parallel-size 4 --hf_overrides '{"num_hidden_layers": 10}'  --compilation-config='{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3, "debug_dump_path": "./debug_dsr1"}'  --no-enable-prefix-caching --profile

(VllmWorker rank=3 pid=3465) -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
(VllmWorker rank=3 pid=3465)                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
(VllmWorker rank=3 pid=3465) -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
(VllmWorker rank=3 pid=3465)                                  _w8a8_block_fp8_matmul         0.00%       0.000us         0.00%       0.000us       0.000us     386.231ms        36.53%     386.231ms      49.772us          7760  
(VllmWorker rank=3 pid=3465) ncclDevKernel_AllReduce_Sum_bf16_RING_LL(ncclDevKern...         0.00%       0.000us         0.00%       0.000us       0.000us     250.732ms        23.71%     250.732ms      92.555us          2709  
(VllmWorker rank=3 pid=3465)                                        fused_moe_kernel         0.00%       0.000us         0.00%       0.000us       0.000us     140.873ms        13.32%     140.873ms      78.003us          1806  
(VllmWorker rank=3 pid=3465) void cutlass::Kernel2<cutlass_80_wmma_tensorop_bf16_...         0.00%       0.000us         0.00%       0.000us       0.000us      90.682ms         8.58%      90.682ms      89.166us          1017  
(VllmWorker rank=3 pid=3465)                                                aten::mm         0.27%       3.082ms         0.36%       4.154ms      32.205us      84.021ms         7.95%      84.021ms     651.322us           129  
(VllmWorker rank=3 pid=3465)                                      record_param_comms         0.51%       5.860ms         0.75%       8.612ms      33.381us      42.999ms         4.07%      42.999ms     166.662us           258  
(VllmWorker rank=3 pid=3465) ncclDevKernel_AllGather_RING_LL(ncclDevKernelArgsSto...         0.00%       0.000us         0.00%       0.000us       0.000us      42.999ms         4.07%      42.999ms     333.324us           129  
(VllmWorker rank=3 pid=3465)                                   nccl:_all_gather_base         0.00%       0.000us         0.00%       0.000us       0.000us      42.999ms         4.07%      42.999ms     333.324us           129  
(VllmWorker rank=3 pid=3465)                                               aten::bmm         5.61%      64.587ms         7.65%      88.019ms      34.383us      21.936ms         2.07%      22.053ms       8.614us          2560  
(VllmWorker rank=3 pid=3465)                                             aten::copy_         2.09%      24.047ms        16.56%     190.663ms      47.008us      19.095ms         1.81%      19.095ms       4.708us          4056  
(VllmWorker rank=3 pid=3465)                              _per_token_group_quant_fp8         0.00%       0.000us         0.00%       0.000us       0.000us      15.634ms         1.48%      15.634ms       1.634us          9566  
(VllmWorker rank=3 pid=3465)                          Memcpy HtoD (Pinned -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.231ms         1.25%      13.231ms      17.028us           777  
(VllmWorker rank=3 pid=3465)                     vllm::unified_attention_with_output        46.33%     533.359ms        72.98%     840.212ms     651.327us      11.806ms         1.12%      45.650ms      35.388us          1290  

With this PR, the FlashInfer kernels are called correctly:

VLLM_TORCH_PROFILER_DIR=./torch_profiler VLLM_DISABLE_COMPILE_CACHE=1 python3 vllm/benchmarks/benchmark_latency.py --model deepseek-ai/DeepSeek-R1  --trust-remote --load_format dummy  --gpu_memory_utilization=0.90 --max-num-seqs=1024  --tensor-parallel-size 4 --hf_overrides '{"num_hidden_layers": 10}'  --compilation-config='{"pass_config": {"enable_fi_allreduce_fusion": true}, "custom_ops": ["+rms_norm"], "level": 3, "debug_dump_path": "./debug_dsr1"}'  --no-enable-prefix-caching --profile

(VllmWorker rank=0 pid=284) -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
(VllmWorker rank=0 pid=284)                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
(VllmWorker rank=0 pid=284) -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
(VllmWorker rank=0 pid=284) void flashinfer::trtllm_allreduce_fusion::allreduce_...         0.00%       0.000us         0.00%       0.000us       0.000us     407.540ms        31.55%     407.540ms     229.212us          1778  
(VllmWorker rank=0 pid=284)                                  _w8a8_block_fp8_matmul         0.00%       0.000us         0.00%       0.000us       0.000us     388.581ms        30.09%     388.581ms      50.075us          7760  
(VllmWorker rank=0 pid=284)                                        fused_moe_kernel         0.00%       0.000us         0.00%       0.000us       0.000us     141.444ms        10.95%     141.444ms      78.319us          1806  
(VllmWorker rank=0 pid=284) void cutlass::Kernel2<cutlass_80_wmma_tensorop_bf16_...         0.00%       0.000us         0.00%       0.000us       0.000us      91.447ms         7.08%      91.447ms      89.830us          1018  
(VllmWorker rank=0 pid=284)                                                aten::mm         0.25%       3.252ms         0.34%       4.440ms      34.420us      84.179ms         6.52%      84.179ms     652.553us           129  

Benchmarks (WIP)

On DeepSeek-R1 on 4xB200 (with `--hf_overrides` limiting the model to 30 layers), TP4, at different concurrency values.

python3 vllm/benchmarks/benchmark_serving.py   --dataset-name random --max-concurrency $concurrency  --model deepseek-ai/DeepSeek-R1 --num-prompts $(( $concurrency * 5 )) --random-input-len 1024 --random-output-len 128
| Concurrency | Mean TPOT (ms), Fusion | Mean TTFT (ms), Fusion | Mean TPOT (ms), No Fusion | Mean TTFT (ms), No Fusion |
|------------:|-----------------------:|-----------------------:|--------------------------:|--------------------------:|
| 1   | 10.37  | 60.36   | 10.61  | 61.82   |
| 2   | 10.79  | 82.69   | 11.06  | 87.95   |
| 4   | 11.26  | 142.63  | 11.49  | 132.84  |
| 8   | 13.03  | 254.97  | 13.23  | 260.07  |
| 16  | 14.38  | 352.86  | 14.74  | 369.91  |
| 32  | 17.43  | 502.23  | 17.5   | 551.48  |
| 64  | 26.37  | 600.93  | 25.33  | 714.73  |
| 128 | 43.28  | 812.48  | 43.49  | 806.06  |
| 256 | 75.91  | 1226.19 | 76.44  | 1234.88 |
| 512 | 144.32 | 2138.59 | 144.74 | 2134.24 |


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request correctly fixes a bug in the condition to enable FlashInfer's Allreduce+Norm fusion. The original logic incorrectly used num_tokens instead of hidden_size when calculating the dynamic size limit, causing the fusion to be disabled for models with large hidden sizes. The fix addresses this, and the provided test results confirm that the fusion is now correctly triggered. I've added one suggestion to improve the readability of the calculation.
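
The readability suggestion itself is not included in this excerpt. Purely as a hypothetical illustration of that kind of change, the corrected calculation could name its intermediate values explicitly (the tensor shape and limit below are made-up example values):

```python
import torch

# Hypothetical illustration only, not the reviewer's actual suggestion.
allreduce_in = torch.empty(16, 7168)       # [num_tokens, hidden_size]
fi_allreduce_fusion_max_token_num = 1024   # assumed config value

num_tokens, hidden_size = allreduce_in.shape
max_fusion_elements = fi_allreduce_fusion_max_token_num * hidden_size
enable_fi_allreduce_fusion = num_tokens * hidden_size <= max_fusion_elements
print(enable_fi_allreduce_fusion)  # True: 16 tokens fit under the 1024-token limit
```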

@xinli-git xinli-git force-pushed the xinli/fix-fi-allreduce branch from f229bcc to 30b722e Compare July 21, 2025 20:48
@xinli-git xinli-git force-pushed the xinli/fix-fi-allreduce branch from 30b722e to bf3fa63 Compare July 21, 2025 20:50
@xinli-git xinli-git force-pushed the xinli/fix-fi-allreduce branch from bf247da to 8195e6c Compare July 21, 2025 21:29
@ilmarkov (Contributor) commented Jul 22, 2025

Thank you for the fix! Looks good to me

@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed performance Performance-related issues labels Jul 22, 2025
@xinli-git (Contributor, Author) commented Jul 22, 2025

Thanks @mgoin @ilmarkov , the CI failure seems to be unrelated and fails with recent vLLM PRs too, e.g. in https://buildkite.com/vllm/ci/builds/24587/steps/canvas?jid=01983248-1e76-4766-bd71-2881bbbe90b2

Please let me know if there is anything I can do to help merge this :)

@mgoin mgoin changed the title Fix Flashifner Allreduce+Norm enable disable calculation based on fi_allreduce_fusion_max_token_num Fix Flashinfer Allreduce+Norm enable disable calculation based on fi_allreduce_fusion_max_token_num Jul 22, 2025
@simon-mo simon-mo merged commit ae268b6 into vllm-project:main Jul 22, 2025
74 of 77 checks passed