[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models #22589
Open
tdoublep wants to merge 5 commits into vllm-project:main from tdoublep:minimax-compile-pr
+75 −36
Changes from 2 commits
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only MiniMaxText01 model."""
import copy
import math
from collections.abc import Iterable
from typing import Optional, Union
@@ -16,12 +15,13 @@
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -44,7 +44,9 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid
@@ -507,20 +509,40 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                  slot_id, 32)
        return hidden

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: MinimaxCacheParams) -> torch.Tensor:
        if not envs.VLLM_USE_V1:
            self._forward(hidden_states, output, positions, kv_caches)
        else:
            torch.ops.vllm.linear_attention(
                hidden_states,
                output,
                positions,
                self.prefix,
            )

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1 and attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, LinearAttentionMetadata)
            num_actual_tokens = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            num_actual_tokens = hidden_states.shape[0]

        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
        qkv32 = qkv.to(torch.float32)
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                assert isinstance(attn_metadata, dict)
                attn_metadata = attn_metadata[self.prefix]
                assert isinstance(attn_metadata, LinearAttentionMetadata)
                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                state_indices_tensor = attn_metadata.state_indices_tensor
@@ -559,13 +581,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                hidden = self._decode_infer(q, k, v, kv_cache,
                                            state_indices_tensor,
                                            attn_metadata)

        hidden = self.norm._forward(hidden)
        gate, _ = self.output_gate(hidden_states)
        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)
        hidden, _ = self.out_proj(hidden)
        return hidden
        output[:num_actual_tokens], _ = self.out_proj(hidden)

class MiniMaxText01Attention(nn.Module):

@@ -635,8 +655,8 @@ def __init__(
        )
        return

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                **kwargs) -> torch.Tensor:
    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                positions: torch.Tensor, **kwargs) -> None:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        qkv, _ = self.qkv_proj(hidden_states)

@@ -648,8 +668,7 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
        else:
            q, k = attn_metadata.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
        output[:], _ = self.o_proj(attn_output)

class MiniMaxText01DecoderLayer(nn.Module):

@@ -794,16 +813,15 @@ def forward(self,
                is_warmup: bool = False,
                **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        layernorm_input = hidden_states
        layernorm_output = self.input_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input
        self_attention_output = self.self_attn(
        self_attention_output = torch.empty_like(layernorm_output)
        self.self_attn(
            hidden_states=layernorm_output,
            output=self_attention_output,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
        )

        residual = residual * self.layernorm_attention_alpha

@@ -817,8 +835,8 @@ def forward(self,
        if self.expert_num == 1:
            hidden_states = self.mlp(layernorm_output)
        else:
            moe_hidden_states = self.block_sparse_moe(
                copy.deepcopy(layernorm_output))
            moe_layernorm_output = layernorm_output.clone()
            moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
            if self.shared_moe:
                before_moe_dtype = layernorm_output.dtype
                moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
@@ -856,17 +874,15 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
        return


@support_torch_compile
class MiniMaxText01Model(nn.Module):

    def __init__(
        self,
        config: MiniMaxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        scheduler_config=None,
        prefix: str = "",
    ) -> None:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: MiniMaxConfig = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        scheduler_config = vllm_config.scheduler_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
@@ -1019,12 +1035,11 @@ def forward(self,
        attn_metadata = forward_context.attn_metadata
        if not envs.VLLM_USE_V1 and attn_metadata is None:
            return None
        if "request_ids_to_seq_ids" not in kwargs:
            kwargs["request_ids_to_seq_ids"] = {}
        if "finished_requests_ids" not in kwargs:
            kwargs["finished_requests_ids"] = []

        if not envs.VLLM_USE_V1:
            if "request_ids_to_seq_ids" not in kwargs:
                kwargs["request_ids_to_seq_ids"] = {}
            if "finished_requests_ids" not in kwargs:
                kwargs["finished_requests_ids"] = []
            (
                minimax_cache_tensors,
                state_indices_tensor,
@@ -1096,7 +1111,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
@@ -1109,12 +1123,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        self.unpadded_vocab_size = self.config.vocab_size
        if hasattr(vllm_config.model_config, "max_model_len"):
            self.config.max_model_len = vllm_config.model_config.max_model_len
        self.model = MiniMaxText01Model(
            self.config,
            quant_config,
            cache_config=vllm_config.cache_config,
            scheduler_config=vllm_config.scheduler_config,
            prefix=maybe_prefix(prefix, "model"))
        self.model = MiniMaxText01Model(vllm_config=vllm_config,
                                        prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
@@ -1433,3 +1443,36 @@ def get_mamba_state_shape_from_config(
            tp_size=parallel_config.tensor_parallel_size,
            head_dim=hf_config.head_dim,
        )


def linear_attention(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    print("layer_name: ", layer_name)
    self = forward_context.no_compile_layers[layer_name]
    self._forward(hidden_states=hidden_states,
                  output=output,
                  positions=positions,
                  kv_caches=None)


def linear_attention_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    return


direct_register_custom_op(
    op_name="linear_attention",
    op_func=linear_attention,
    mutates_args=["output"],
    fake_impl=linear_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
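As context for this hunk: the V1 path routes linear attention through a registered custom op that mutates its output argument and provides a fake implementation, so the model can be compiled while the attention kernel itself stays outside the traced graph. Below is a minimal standalone sketch of the same pattern using plain PyTorch's torch.library.custom_op; the demo::linear_attention name and the copy_ placeholder body are illustrative assumptions, not part of this PR.

import torch


# Hypothetical demo op mirroring the mutates_args=["output"] registration above;
# the body is a placeholder, not MiniMax's linear attention kernel.
@torch.library.custom_op("demo::linear_attention", mutates_args=("output",))
def demo_linear_attention(hidden_states: torch.Tensor, output: torch.Tensor,
                          positions: torch.Tensor, layer_name: str) -> None:
    # Real implementation: write the result into `output` in place.
    output.copy_(hidden_states)


@demo_linear_attention.register_fake
def _(hidden_states, output, positions, layer_name) -> None:
    # Fake impl: shape/dtype propagation only; a mutating op returns nothing.
    return


@torch.compile
def run(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # The compiled graph treats the op as opaque and preserves the in-place write to `out`.
    torch.ops.demo.linear_attention(x, out, positions, "layers.0.linear_attn")
    return out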
The type hints for the forward and _forward methods in MiniMaxText01LinearAttention have some issues that should be corrected for code clarity and correctness:

- The return type of forward (L514) and _forward (L527) is annotated as torch.Tensor, but neither function returns a value. They should be annotated with -> None.
- The kv_caches parameter in _forward (L527) is annotated as MinimaxCacheParams, but it is called with None from the linear_attention custom op (L1460). It should be Optional[MinimaxCacheParams].
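A minimal sketch of the suggested annotations, assuming the signatures shown in the diff above (bodies elided; Optional imported from typing):

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: Optional[MinimaxCacheParams]) -> None:
        ...

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: Optional[MinimaxCacheParams]) -> None:
        ...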
.