Skip to content

Conversation

@luccafong
Copy link
Collaborator

@luccafong luccafong commented Sep 15, 2025

Purpose

  • Support torchrun DP/EP with MOE models
  • Add CI tests for MOE models on torchrun

Test Plan/Results

Simple Example

torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py

lm_eval

Need patch lm_eval PR EleutherAI/lm-evaluation-harness#3304

torchrun --nproc-per-node=8 --no-python  lm_eval     --model vllm     --model_args "pretrained=/data/local/models/oss/DeepSeek-R1-0528,max_model_len=20000,gpu_memory_utilization=0.9,tensor_parallel_size=1,data_parallel_size=8,enable_expert_parallel=true,max_num_seqs=256,distributed_executor_backend=external_launcher"     --batch_size 256     --task gsm8k
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.956|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.953|±  |0.0058|

baseline of non torchrun version

lm_eval --model vllm --model_args "pretrained=/data/local/models/oss/DeepSeek-R1-0528,max_model_len=20000,gpu_memory_utilization=0.9,tensor_parallel_size=8,data_parallel_size=1,enable_expert_parallel=true,max_num_seqs=256" --batch_size 256 --task gsm8k
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9545|±  |0.0057|

Added CI tests

  - TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
  # test with torchrun tp=2, pp=2 and dp=1
  - PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
  # test with torchrun tp=1 and dp=4 with ep
  - DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
  # test with torchrun tp=2 and dp=2 with ep
  - TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added documentation Improvements or additions to documentation v1 labels Sep 15, 2025
@luccafong luccafong force-pushed the torchrun_dp branch 2 times, most recently from 3f09d97 to 521eeab Compare September 15, 2025 19:54
@mergify
Copy link

mergify bot commented Sep 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@zhuohan123 zhuohan123 disabled auto-merge September 22, 2025 17:54
@zhuohan123
Copy link
Member

Will force merge since the CI failure is not caused by this PR and is being fixed by #25396

@zhuohan123 zhuohan123 merged commit 922979b into vllm-project:main Sep 22, 2025
76 of 78 checks passed
@facebook-github-bot
Copy link

@kingsmad has imported this pull request. If you are a Meta employee, you can view this in D82998295.

@facebook-github-bot
Copy link

This pull request has been imported. If you are a Meta employee, you can view this in D82998295.

FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…roject#24899)

Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Zhuohan Li <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Signed-off-by: charlifu <[email protected]>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Zhuohan Li <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Signed-off-by: yewentao256 <[email protected]>
gjc0824 pushed a commit to gjc0824/vllm that referenced this pull request Oct 10, 2025
…roject#24899)

Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Zhuohan Li <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Signed-off-by: gaojc <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…roject#24899)

Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Zhuohan Li <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…roject#24899)

Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Zhuohan Li <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants