
Conversation

tdoublep (Member) commented Aug 10, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing the test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR removes the --enforce-eager constraint for MiniMax models. It adds support for piecewise CUDA graphs around the linear attention and enables torch.compile for the rest of the model.

It would be great if the MiniMax team could run additional correctness checks on the real model.

cc @rogeryoungh @qscqesze @heheda12345

Test Plan

I have tested it locally using Goekdeniz-Guelmez/MiniMax01Text-Dev. I haven't included that test in this PR because it depends on #21549 landing first; unfortunately, FlashInfer doesn't support that tiny model.

Test Result

The test passes (i.e., the V1 results with compilation enabled match the V0 results).
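
For context, a minimal local consistency check along these lines could look like the sketch below. This is not the test that will land with #21549: the tiny model name comes from the test plan above, but the prompts, sampling settings, and exact string comparison are assumptions (greedy outputs are not guaranteed to match token-for-token between the eager and compiled paths).

# Sketch of a local eager-vs-compiled consistency check (not the test from #21549).
# Assumes a CUDA machine with this branch installed; prompts are arbitrary.
from vllm import LLM, SamplingParams

MODEL = "Goekdeniz-Guelmez/MiniMax01Text-Dev"
PROMPTS = ["The capital of France is", "The quick brown fox"]
GREEDY = SamplingParams(temperature=0.0, max_tokens=32)

def generate(enforce_eager: bool) -> list[str]:
    # In practice the two engines should run in separate processes to avoid
    # holding GPU memory twice; kept in one script here for brevity.
    llm = LLM(model=MODEL, trust_remote_code=True, enforce_eager=enforce_eager)
    return [out.outputs[0].text for out in llm.generate(PROMPTS, GREEDY)]

eager = generate(enforce_eager=True)      # no CUDA graphs, no torch.compile
compiled = generate(enforce_eager=False)  # piecewise CUDA graphs + torch.compile

for prompt, a, b in zip(PROMPTS, eager, compiled):
    assert a == b, f"Mismatch for {prompt!r}: {a!r} vs {b!r}"
print("eager and compiled outputs match")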

(Optional) Documentation Update

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request refactors the MiniMax-Text model to enable torch.compile and piecewise CUDA graph capture. The changes primarily involve modifying forward passes to use output buffers instead of returning tensors, which is a key pattern for compiler compatibility. A custom op linear_attention is introduced to serve as a boundary for piecewise compilation. The changes are generally well-executed and align with the goal of improving performance through compilation. My feedback focuses on improving code quality by correcting type hints and removing a leftover debug statement.
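
For readers unfamiliar with the pattern described here, below is a minimal, generic sketch in plain PyTorch (not the actual vLLM code; the op name demo::linear_attention, the shapes, and the toy math are all placeholders). It shows how a custom op that writes into a caller-provided output buffer stays opaque to torch.compile, which lets the compiler treat it as a single node and, in vLLM, split the graph around it for piecewise CUDA graph capture.

# Generic illustration of the "custom op with an output buffer" pattern.
# Plain PyTorch (>= 2.4); names, shapes, and the toy math are placeholders,
# not the vLLM MiniMax implementation.
import torch

@torch.library.custom_op("demo::linear_attention", mutates_args=("output",))
def linear_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     output: torch.Tensor) -> None:
    # The real kernel also updates recurrent state and is not traceable by
    # torch.compile; writing into `output` in place means the op returns
    # nothing and appears as a single opaque node in the captured graph.
    output.copy_(torch.matmul(torch.matmul(q, k.transpose(-1, -2)), v))

@linear_attention.register_fake
def _(q, k, v, output) -> None:
    return None

class Block(torch.nn.Module):
    def forward(self, q, k, v):
        out = torch.empty_like(v)
        torch.ops.demo.linear_attention(q, k, v, out)  # opaque boundary
        return out * 2.0  # ops around the boundary are still compiled

block = torch.compile(Block())
q = k = v = torch.randn(1, 8, 16)
print(block(q, k, v).shape)  # torch.Size([1, 8, 16])

In vLLM itself the op is, as far as I understand, registered through the project's own helper and its name added to the compilation config's splitting_ops, which is what actually drives the piecewise split; the snippet above only illustrates the compiler-boundary idea.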

heheda12345 (Collaborator) left a comment


LGTM! @tdoublep, can you update the documentation?
Please merge after the correctness is verified.

heheda12345 added the ready label ("ONLY add when PR is ready to merge/full CI is needed") on Aug 11, 2025
rogeryoungh commented Aug 14, 2025

Thank you for your work! We ran validation using the same build environment as last time, but the results seemed a bit unusual, which could be due to a problem with the model inference process. Here are the reproduction steps and outcomes for your reference.

We compiled and installed your PR inside the vllm-openai:v0.10.0 Docker image via pip install --no-build-isolation /tmp/vllm_patched/.

The deployment command was the same as in the previous PR, just with --enforce-eager removed:

python3 -m vllm.entrypoints.api_server --model /data/xxx/model/MiniMax-Text-01/ --tensor-parallel-size 8 --trust-remote-code --quantization experts_int8 --max_model_len 8192 --dtype bfloat16 --no-enable-prefix-caching

Here are the test results:

For gsm8k:

python3 bench_other.py --num-questions 500 --num-shots 5 --backend vllm --port 8000 --host http://127.0.0.1
# ...
Accuracy: 0.010
Invalid: 0.018
Latency: 202.115 s

For mmlu:

python3 bench_other.py --nsub 200 --backend vllm --port 8000 --host http://127.0.0.1
# ...
Total latency: 1405.775
Average accuracy: 0.801

Upon checking the model's output for GSM8K, we noticed a significant number of extra newlines and garbled characters, indicating an abnormal output format.

On the other hand, for the MMLU benchmark, which only requires a single letter as the answer, the accuracy is only slightly lower than normal. We suspect this simpler output format might be masking some underlying issues that are more apparent in the GSM8K results.

tdoublep (Member, Author):

@rogeryoungh Thanks for the eval. There must be some bug that isn't being hit when I run with the tiny model. I will take another look at it.


mergify bot commented Aug 15, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @tdoublep.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Aug 15, 2025
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
tdoublep (Member, Author):

I was able to reproduce the bad lm_eval results and dug into what is going on here.

The problem is related to the implementation of the rotary embedding for this model (it is not compatible with torch.compile). I've replaced it with the call to get_rope that the other models use. If there was any particular reason why a custom implementation (e.g., MiniMaxText01RotaryEmbedding) was needed, please let me know; I took a quick look through the code and couldn't see one.
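
Roughly, the replacement follows the usual get_rope call pattern shown in the sketch below; the sizes, rope base, and is_neox_style value are placeholders rather than values read from the MiniMax config, so treat it as an illustration of the call pattern rather than the actual diff.

# Illustration of the get_rope call pattern (placeholder sizes, not the real diff).
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim, num_heads = 128, 8
rotary_emb = get_rope(
    head_size=head_dim,
    rotary_dim=head_dim // 2,   # placeholder partial rotary dim
    max_position=4096,          # placeholder max_position_embeddings
    base=10000,                 # placeholder rope_theta
    is_neox_style=True,         # assumption, not taken from the model config
)

positions = torch.arange(16)
q = torch.randn(16, num_heads * head_dim)
k = torch.randn(16, num_heads * head_dim)
q, k = rotary_emb(positions, q, k)  # same call pattern as other vLLM models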

I deploy the model as follows:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve MiniMaxAI/MiniMax-Text-01 \
	--tensor-parallel-size 8 \
	--trust-remote-code \
	--quantization experts_int8  \
	--max_model_len 4096 \
	--dtype bfloat16 \
	--gpu-memory-utilization 0.95 \
	--no-enable-prefix-caching

and then I run eval with:

lm_eval   --model local-completions   \
	--model_args base_url=http://localhost:8000/v1/completions,tokenizer=MiniMaxAI/MiniMax-Text-01 \
	--tasks gsm8k  \
	--batch_size 128 \
	--num_fewshot 5 \
	--limit 500

which now produces the expected output:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.898|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.892|±  |0.0139|

@rogeryoungh Could you give it another try on your end?

tdoublep (Member, Author):

After pulling in the latest changes from main, we can now also deploy with the default FlashAttention backend (instead of FlashInfer):

VLLM_USE_V1=1 vllm serve MiniMaxAI/MiniMax-Text-01 \
	--tensor-parallel-size 8 \
	--trust-remote-code \
	--quantization experts_int8  \
	--max_model_len 4096 \
	--dtype bfloat16 \
	--gpu-memory-utilization 0.95 \
	--no-enable-prefix-caching

The gsm8k eval above now produces:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.904|±  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.900|±  |0.0134|

Signed-off-by: Thomas Parnell <[email protected]>
rogeryoungh:

Great work! I have verified the changes, and the implementation now works as expected. On GSM8k the accuracy is 0.908, and on MMLU the average accuracy is 0.847.

Deployment command:

VLLM_ATTENTION_BACKEND=FLASHINFER VLLM_USE_V1=1 python3 -m vllm.entrypoints.api_server \
  --model /data/xxx/model/MiniMax-Text-01/ \
  --tensor-parallel-size 8 \
  --trust-remote-code \
  --quantization experts_int8 \
  --max_model_len 4096 \
  --dtype bfloat16 \
  --no-enable-prefix-caching

GSM8k test:

Accuracy: 0.908
Invalid: 0.000
Latency: 203.588 s

MMLU test:

Total latency: 1489.795
Average accuracy: 0.847

qscqesze (Contributor):

> I was able to reproduce the bad lm_eval results and dug into what is going on here.
>
> The problem is related to the implementation of the rotary embedding for this model (it is not compatible with torch.compile). I've replaced it with the call to get_rope that the other models use. If there was any particular reason why a custom implementation (e.g., MiniMaxText01RotaryEmbedding) was needed, please let me know; I took a quick look through the code and couldn't see one.

Thanks a lot for the fix! The custom implementation didn’t have any special reason — it was just how it was written back then. Really appreciate your improvement.

@rogeryoungh
Copy link

I also retested with the default FlashAttention backend. On GSM8k, the accuracy was 0.904, and on MMLU the average accuracy was 0.847. Everything looks good now.

heheda12345 merged commit dd58932 into vllm-project:main on Aug 27, 2025
40 checks passed
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
dumb0002 pushed a commit to dumb0002/vllm that referenced this pull request Aug 28, 2025
2015aroras pushed a commit to 2015aroras/vllm that referenced this pull request Aug 29, 2025