
[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models #22589


Open

tdoublep wants to merge 5 commits into main

Conversation

@tdoublep (Member) commented Aug 10, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR removes the --enforce-eager constraint for MiniMax models. It adds piecewise CUDA graph support for the linear attention layers and enables torch.compile for the rest of the model.

It would be great if the MiniMax team could run additional correctness checks on the real model.

cc @rogeryoungh @qscqesze @heheda12345
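
For illustration only (not part of this PR's diff), here is a minimal offline-inference sketch of what the change enables; the model id, prompt, and sampling settings below are placeholders:

    # Hypothetical usage sketch: with this PR, MiniMax-Text models no longer
    # need enforce_eager=True. The default path compiles the model with
    # torch.compile and captures piecewise CUDA graphs around linear attention.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="MiniMaxAI/MiniMax-Text-01",  # illustrative model id
        trust_remote_code=True,
        # enforce_eager=True is no longer required for MiniMax-Text models
    )
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)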

Test Plan

I have tested this locally using Goekdeniz-Guelmez/MiniMax01Text-Dev. That test is not included in this PR because it depends on #21549 landing first; unfortunately, FlashInfer does not support that tiny model.

Test Result

The test passes: the V1 results with compilation match the V0 results.
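
The comparison test itself is not included in this PR (see Test Plan); the sketch below is a simplified, hypothetical analogue that checks greedy outputs with and without enforce_eager on the tiny dev model, rather than the actual V1-vs-V0 comparison:

    # Hypothetical correctness check (not the PR's actual test): compare greedy
    # outputs of the eager path against the compiled + piecewise-CUDA-graph path.
    from vllm import LLM, SamplingParams

    MODEL = "Goekdeniz-Guelmez/MiniMax01Text-Dev"  # tiny dev model from the test plan
    PROMPTS = ["Hello, my name is", "The future of AI is"]
    GREEDY = SamplingParams(temperature=0.0, max_tokens=32)

    def generate(enforce_eager: bool) -> list[str]:
        # In practice each configuration may need its own process to free GPU memory.
        llm = LLM(model=MODEL, trust_remote_code=True, enforce_eager=enforce_eager)
        return [out.outputs[0].text for out in llm.generate(PROMPTS, GREEDY)]

    assert generate(enforce_eager=True) == generate(enforce_eager=False), \
        "compiled outputs diverge from eager outputs"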

(Optional) Documentation Update

Signed-off-by: Thomas Parnell <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request refactors the MiniMax-Text model to enable torch.compile and piecewise CUDA graph capture. The changes primarily involve modifying forward passes to use output buffers instead of returning tensors, which is a key pattern for compiler compatibility. A custom op linear_attention is introduced to serve as a boundary for piecewise compilation. The changes are generally well-executed and align with the goal of improving performance through compilation. My feedback focuses on improving code quality by correcting type hints and removing a leftover debug statement.
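
For readers unfamiliar with this pattern, here is a generic PyTorch sketch (all names, including "demo::linear_attention", are illustrative; vLLM registers the real op through its own utilities and resolves the layer from the forward context) of wrapping an attention call in a custom op that writes into a preallocated output buffer:

    # Generic illustration of the output-buffer + custom-op pattern.
    import torch

    @torch.library.custom_op("demo::linear_attention", mutates_args=("output",))
    def linear_attention(hidden_states: torch.Tensor, output: torch.Tensor,
                         positions: torch.Tensor, layer_name: str) -> None:
        # The real implementation would resolve the layer by layer_name and run
        # its _forward(); here we just fill the buffer to show the contract:
        # no return value, results are written in place.
        output.copy_(hidden_states)

    @linear_attention.register_fake
    def _(hidden_states, output, positions, layer_name) -> None:
        # Fake (meta) implementation: the op only mutates `output`, so tracing
        # needs no shape propagation beyond the inputs.
        return None

    # Call site, e.g. inside the model's forward:
    #   torch.ops.demo.linear_attention(hidden_states, output, positions, "layers.0")

The key point, echoed in the review summary above, is that the op communicates through a preallocated output buffer rather than a return value, which is the pattern used here for compiler compatibility and piecewise CUDA graph capture.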

Comment on lines 512 to 527

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: MinimaxCacheParams) -> torch.Tensor:
        if not envs.VLLM_USE_V1:
            self._forward(hidden_states, output, positions, kv_caches)
        else:
            torch.ops.vllm.linear_attention(
                hidden_states,
                output,
                positions,
                self.prefix,
            )

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
gemini-code-assist bot (Contributor) commented:


Severity: high

The type hints for forward and _forward methods in MiniMaxText01LinearAttention have some issues that should be corrected for code clarity and correctness:

  1. The return type for both forward (L514) and _forward (L527) is annotated as torch.Tensor, but neither function returns a value. They should be annotated with -> None.
  2. The kv_caches parameter in _forward (L527) is annotated as MinimaxCacheParams, but it's called with None from the linear_attention custom op (L1460). It should be Optional[MinimaxCacheParams].
Suggested change

-    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams) -> torch.Tensor:
-        if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
-        else:
-            torch.ops.vllm.linear_attention(
-                hidden_states,
-                output,
-                positions,
-                self.prefix,
-            )
-    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
-                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: MinimaxCacheParams) -> None:
+        if not envs.VLLM_USE_V1:
+            self._forward(hidden_states, output, positions, kv_caches)
+        else:
+            torch.ops.vllm.linear_attention(
+                hidden_states,
+                output,
+                positions,
+                self.prefix,
+            )
+    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                 positions: torch.Tensor,
+                 kv_caches: Optional[MinimaxCacheParams]) -> None:

    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    print("layer_name: ", layer_name)
gemini-code-assist bot (Contributor) commented:


Severity: high

This print statement appears to be a leftover from debugging. It should be removed to avoid polluting logs in production.

Signed-off-by: Thomas Parnell <[email protected]>
@heheda12345 (Collaborator) left a comment


LGTM! @tdoublep, can you update the documentation?
Please merge after correctness is verified.

@heheda12345 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Aug 11, 2025