[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models #22589
base: main
Conversation
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Code Review
This pull request refactors the MiniMax-Text model to enable torch.compile and piecewise CUDA graph capture. The changes primarily involve modifying forward passes to write into output buffers instead of returning tensors, which is a key pattern for compiler compatibility. A custom op, linear_attention, is introduced to serve as a boundary for piecewise compilation. The changes are generally well executed and align with the goal of improving performance through compilation. My feedback focuses on improving code quality by correcting type hints and removing a leftover debug statement.
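For context, here is a minimal sketch of how a mutating op like this can be registered as a compilation boundary using PyTorch's custom-op API. The PR itself registers torch.ops.vllm.linear_attention through vLLM's own registration helper (and such ops are typically also named among the compilation config's splitting ops); the sketch::* namespace and placeholder body below are purely illustrative:

import torch

@torch.library.custom_op("sketch::linear_attention", mutates_args=("output",))
def _linear_attention(hidden_states: torch.Tensor, output: torch.Tensor,
                      positions: torch.Tensor, layer_name: str) -> None:
    # A real implementation would look up the layer by name and run its eager
    # attention path, writing the result into the pre-allocated output buffer.
    output.copy_(hidden_states)  # placeholder compute only

@_linear_attention.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor,
      positions: torch.Tensor, layer_name: str) -> None:
    # Fake (meta) implementation so the op can be traced without running the
    # kernel; the compiler only needs to know it mutates `output` in place.
    return None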
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: MinimaxCacheParams) -> torch.Tensor:
    if not envs.VLLM_USE_V1:
        self._forward(hidden_states, output, positions, kv_caches)
    else:
        torch.ops.vllm.linear_attention(
            hidden_states,
            output,
            positions,
            self.prefix,
        )

def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
             positions: torch.Tensor,
             kv_caches: MinimaxCacheParams) -> torch.Tensor:
The type hints for the forward and _forward methods in MiniMaxText01LinearAttention have some issues that should be corrected for code clarity and correctness:
- The return type for both forward (L514) and _forward (L527) is annotated as torch.Tensor, but neither function returns a value. They should be annotated with -> None.
- The kv_caches parameter in _forward (L527) is annotated as MinimaxCacheParams, but it is called with None from the linear_attention custom op (L1460). It should be Optional[MinimaxCacheParams].
Suggested change:

Before:
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: MinimaxCacheParams) -> torch.Tensor:
    if not envs.VLLM_USE_V1:
        self._forward(hidden_states, output, positions, kv_caches)
    else:
        torch.ops.vllm.linear_attention(
            hidden_states,
            output,
            positions,
            self.prefix,
        )

def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
             positions: torch.Tensor,
             kv_caches: MinimaxCacheParams) -> torch.Tensor:

After:
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: MinimaxCacheParams) -> None:
    if not envs.VLLM_USE_V1:
        self._forward(hidden_states, output, positions, kv_caches)
    else:
        torch.ops.vllm.linear_attention(
            hidden_states,
            output,
            positions,
            self.prefix,
        )

def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
             positions: torch.Tensor,
             kv_caches: Optional[MinimaxCacheParams]) -> None:
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    print("layer_name: ", layer_name)
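For reference, a hedged sketch of what the rest of this op body conventionally looks like in vLLM (the print above is the leftover debug statement called out in the review summary). The forward-context lookup shown follows the usual pattern for such ops and may differ from the final PR:

import torch
from vllm.forward_context import ForwardContext, get_forward_context

def linear_attention(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    # Resolve the eager (non-compiled) layer registered under this prefix and
    # let it write its result into the pre-allocated output buffer.
    self = forward_context.no_compile_layers[layer_name]
    self._forward(hidden_states, output, positions, None)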
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
LGTM! @tdoublep can you update the document?
Please merge after the correctness is verified.
Essential Elements of an Effective PR Description Checklist
- Update supported_models.md and examples for a new model.

Purpose
This PR removes the --enforce-eager constraint for MiniMax models. It adds support for piecewise CUDA graphs for the linear attention layers and enables torch.compile for the rest of the model. It would be great if the MiniMax team could run additional correctness checks on the real model.
cc @rogeryoungh @qscqesze @heheda12345
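For illustration only, a minimal offline-inference sketch of running the model without eager enforcement; the checkpoint name is the tiny dev model mentioned in the test plan below, and the flags shown are assumptions, not changes made by this PR:

from vllm import LLM, SamplingParams

# With this PR, enforce_eager should no longer be required for MiniMax-Text;
# the normal path exercises torch.compile and piecewise CUDA graphs.
llm = LLM(model="Goekdeniz-Guelmez/MiniMax01Text-Dev", enforce_eager=False)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)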
Test Plan
I have tested it locally using Goekdeniz-Guelmez/MiniMax01Text-Dev. I haven't included that test in this PR because #21549 needs to land first: unfortunately, FlashInfer doesn't support that tiny model.

Test Result
The test passes (i.e., V1 results with compile match V0 results).
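A hedged sketch of the kind of V0-vs-V1 consistency check described above (greedy decoding, identical prompts); the real test is not part of this PR and would normally isolate each engine configuration in its own process:

import os
from vllm import LLM, SamplingParams

PROMPTS = ["The capital of France is"]
PARAMS = SamplingParams(temperature=0.0, max_tokens=8)  # greedy for determinism

def generate_texts(use_v1: bool) -> list[str]:
    # In practice each configuration should run in a separate process so the
    # env var and GPU memory do not leak between runs.
    os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
    llm = LLM(model="Goekdeniz-Guelmez/MiniMax01Text-Dev",
              enforce_eager=not use_v1)  # V1 path exercises compile + CUDA graphs
    return [o.outputs[0].text for o in llm.generate(PROMPTS, PARAMS)]

def test_v1_compile_matches_v0():
    assert generate_texts(True) == generate_texts(False)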
(Optional) Documentation Update