
Conversation

@yewentao256
Member

@yewentao256 yewentao256 commented Jan 6, 2026

Purpose

Part of #31755

Here we add a kernel for faster calculation of the problem sizes: instead of counting occurrences in topk_ids, the per-expert sizes are derived directly from expert_first_token_offset.
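
Conceptually, the new kernel replaces a per-token counting pass over topk_ids with a difference of adjacent entries in expert_first_token_offset. A minimal PyTorch sketch of that equivalence (the function name and the flat (M, N, K) row layout are illustrative only; the real op is a CUDA kernel that fills both GEMMs' sizes in a single launch):

import torch

def problem_sizes_from_expert_offsets(offsets: torch.Tensor, N: int, K: int) -> torch.Tensor:
    # After permutation, tokens routed to expert e occupy the contiguous
    # slice [offsets[e], offsets[e + 1]), so each expert's M dimension is
    # just the difference of adjacent offsets; no counting pass is needed.
    m = (offsets[1:] - offsets[:-1]).to(torch.int32)
    sizes = torch.empty(m.numel(), 3, dtype=torch.int32, device=offsets.device)
    sizes[:, 0] = m  # M: number of tokens routed to this expert
    sizes[:, 1] = N  # N: output dim of the grouped GEMM
    sizes[:, 2] = K  # K: reduction dim
    return sizes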

Test

export MODEL="zai-org/GLM-4.7-FP8"

vllm serve $MODEL -tp 8 --port 9256 --enable-expert-parallel --max_num_seqs 128

Acc

lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=$MODEL,num_concurrent=1024" --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.9469|±  |0.0062|
|     |       |strict-match    |     5|exact_match||0.9462|±  |0.0062|

With EPLB:

vllm serve $MODEL -tp 8 --port 9256 --enable-expert-parallel --max_num_seqs 128 --enable-eplb

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.9462|±  |0.0062|
|     |       |strict-match    |     5|exact_match||0.9454|±  |0.0063|

Perf

vllm bench serve --model $MODEL --dataset-name random --host 127.0.0.1 --port 9256 --random-input-len 2 --random-output-len 128 --request-rate inf --num-prompts 128

# Now
============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  4.79      
Total input tokens:                      256       
Total generated tokens:                  16384     
Request throughput (req/s):              26.71     
Output token throughput (tok/s):         3419.28   
Peak output token throughput (tok/s):    3584.00   
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          3472.71   
---------------Time to First Token----------------
Mean TTFT (ms):                          165.77    
Median TTFT (ms):                        158.45    
P99 TTFT (ms):                           187.77    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          36.17     
Median TPOT (ms):                        36.17     
P99 TPOT (ms):                           36.22     
---------------Inter-token Latency----------------
Mean ITL (ms):                           36.17     
Median ITL (ms):                         36.17     
P99 ITL (ms):                            39.66     
==================================================
# main 
============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  5.05      
Total input tokens:                      256       
Total generated tokens:                  16384     
Request throughput (req/s):              25.37     
Output token throughput (tok/s):         3246.74   
Peak output token throughput (tok/s):    3456.00   
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          3297.47   
---------------Time to First Token----------------
Mean TTFT (ms):                          169.53    
Median TTFT (ms):                        160.36    
P99 TTFT (ms):                           196.47    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          38.14     
Median TPOT (ms):                        38.17     
P99 TPOT (ms):                           38.20     
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.14     
Median ITL (ms):                         38.08     
P99 ITL (ms):                            42.53     
==================================================

Note

Introduces a faster path for CUTLASS MoE problem-size setup and wires it through C++/Python to fused MoE.

  • New CUDA kernel + entry points: compute_problem_sizes_from_expert_offsets with callers and registry (moe_data.cu, scaled_mm_entry.cu, torch_bindings.cpp) and Python API ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets (a usage sketch follows this list)
  • Replace topk_ids counting with expert-offset–based sizing in cutlass_moe.py (both FP8 and W4A8 paths): allocate problem_sizes{1,2} for local_E, call the new op, derive expert_offsets from expert_first_token_offset, and always pass expert_first_token_offset to moe_unpermute
  • Swap-AB control via size heuristic or forced for RS GEMM; use VLLM_DISPATCH_BOOL for launch-time specialization; minor pointer/type cleanups
  • Public headers updated (ops.h) and Python wrappers added in _custom_ops.py
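
For reference, a sketch of how the new op is called (the positional signature follows the diff excerpt quoted later in this thread; the shapes, values, and dtypes here are illustrative assumptions):

import torch
from vllm import _custom_ops as ops

# Illustrative sizes: local_E experts, GEMM dims N and K.
local_E, N, K = 4, 1536, 4096
# Cumulative first-token offset per local expert (int64 assumed),
# as produced by the permutation step.
expert_first_token_offset = torch.tensor(
    [0, 3, 3, 10, 16], dtype=torch.int64, device="cuda"
)
problem_sizes1 = torch.empty(local_E, 3, dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty(local_E, 3, dtype=torch.int32, device="cuda")

# One tiny launch fills both GEMMs' per-expert problem-size rows.
# swap_ab=False here; the real caller picks it via a size heuristic.
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, False
)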




Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for calculating Mixture-of-Experts (MoE) problem sizes in CUTLASS kernels. It replaces a kernel that computes expert token counts from topk_ids with a more efficient kernel that derives them from expert_first_token_offset. This change is propagated through the C++ ops, Python bindings, and the cutlass_moe.py layer, resulting in a notable end-to-end throughput improvement. The changes are logical and well-implemented. I have one critical comment regarding a potential integer overflow that could lead to correctness issues.

@yewentao256 yewentao256 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jan 6, 2026
Collaborator

@ProExpertProg ProExpertProg left a comment


This seems like a good opportunity for a Triton kernel, did you try that?

yewentao256 and others added 3 commits January 7, 2026 17:59
@yewentao256
Member Author

yewentao256 commented Jan 7, 2026

This seems like a good opportunity for a Triton kernel, did you try that?

Not yet. Why a Triton kernel?

@mgoin
Member

mgoin commented Jan 7, 2026

This seems like a good opportunity for a Triton kernel, did you try that?

Not yet. Why a Triton kernel?

I agree, a Triton kernel would be simpler and could be completely local to cutlass_moe.py
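
For illustration, a minimal Triton sketch of what such a local kernel could look like (a sketch of the idea only, not the kernel that was actually benchmarked below; the (M, 2N, K)/(M, K, N) row layout mirrors the existing counting kernel in moe_data.cu, and the names and dtypes are assumptions):

import torch
import triton
import triton.language as tl

@triton.jit
def _fill_problem_sizes(
    offsets_ptr,  # int64 [E + 1]: first-token offset per local expert
    sizes1_ptr,   # int32 [E, 3]: per-expert problem sizes for GEMM 1
    sizes2_ptr,   # int32 [E, 3]: per-expert problem sizes for GEMM 2
    N: tl.constexpr,
    K: tl.constexpr,
):
    # One program per expert: M is the difference of adjacent offsets.
    e = tl.program_id(0)
    m = (tl.load(offsets_ptr + e + 1) - tl.load(offsets_ptr + e)).to(tl.int32)
    tl.store(sizes1_ptr + 3 * e + 0, m)
    tl.store(sizes1_ptr + 3 * e + 1, 2 * N)
    tl.store(sizes1_ptr + 3 * e + 2, K)
    tl.store(sizes2_ptr + 3 * e + 0, m)
    tl.store(sizes2_ptr + 3 * e + 1, K)
    tl.store(sizes2_ptr + 3 * e + 2, N)

def fill_problem_sizes(offsets, sizes1, sizes2, N, K):
    # Launch one program per local expert.
    E = offsets.numel() - 1
    _fill_problem_sizes[(E,)](offsets, sizes1, sizes2, N, K)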

@yewentao256
Member Author

yewentao256 commented Jan 8, 2026

@ProExpertProg @mgoin I wrote a Triton kernel, and it makes TTFT slower:

============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  4.78      
Total input tokens:                      256       
Total generated tokens:                  16384     
Request throughput (req/s):              26.77     
Output token throughput (tok/s):         3426.12   
Peak output token throughput (tok/s):    3584.00   
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          3479.65   
---------------Time to First Token----------------
Mean TTFT (ms):                          172.63    
Median TTFT (ms):                        165.87    
P99 TTFT (ms):                           199.96    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          36.01     
Median TPOT (ms):                        36.04     
P99 TPOT (ms):                           36.08     
---------------Inter-token Latency----------------
Mean ITL (ms):                           36.01     
Median ITL (ms):                         35.96     
P99 ITL (ms):                            40.37     
==================================================

I don't think we need to involve Triton in this PR; if needed, we can do a follow-up PR and tune it accordingly.

Comment on lines -111 to -115
if expert_map is not None:
    "Translate info from expert_map to topk_ids"
    local_topk_ids = torch.where(
        expert_map[topk_ids] != -1, expert_map[topk_ids], -1
    )
Member


Please verify the correctness of removing the expert_map logic here. I assume this works because moe_permute already handles the mapping, but I'm not sure. I think you should test accuracy with EP and EPLB to properly exercise this case.

Member Author


|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.9462|±  |0.0062|
|     |       |strict-match    |     5|exact_match||0.9454|±  |0.0063|

Tested with EPLB; results added to the PR description as well.

Member

@mgoin mgoin left a comment


Thanks, LGTM!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Jan 9, 2026
Comment on lines +175 to +176
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab
Collaborator


Could we use it for CUTLASS MoE FP4 too?

Member Author


That seems out of this PR's scope; I can test it, and if it can be used, I will do a follow-up PR for that.

@simon-mo simon-mo merged commit 308feab into main Jan 9, 2026
98 of 100 checks passed
@simon-mo simon-mo deleted the wentao-optimize-cutlass-moe branch January 9, 2026 19:13
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Jan 9, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…hput improvement, 2.2% TTFT improvement (vllm-project#31830)
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
aykoppol pushed a commit to aykoppol/vllm that referenced this pull request Jan 21, 2026
daje0601 pushed a commit to daje0601/vllm that referenced this pull request Jan 22, 2026

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done


6 participants