From 9ad6271231a0ee26181bd131f807aafa01a57670 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 10 Aug 2025 05:26:47 -0400 Subject: [PATCH 1/5] Enable compile for minimax Signed-off-by: Thomas Parnell --- vllm/config/compilation.py | 1 + vllm/model_executor/models/minimax_text_01.py | 132 ++++++++++++------ 2 files changed, 89 insertions(+), 44 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 8a78d811b9a2..f6ffe4505331 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -425,4 +425,5 @@ def set_splitting_ops_for_v1(self): "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2", + "vllm.linear_attention", ] diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 3d14a6ad5c3a..2cf52b0e5a44 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only MiniMaxText01 model.""" -import copy import math from collections.abc import Iterable from typing import Optional, Union @@ -16,12 +15,13 @@ from vllm import envs from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -44,7 +44,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from .interfaces import HasInnerState, IsHybrid @@ -507,20 +509,41 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, slot_id, 32) return hidden - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: MinimaxCacheParams) -> torch.Tensor: + if not envs.VLLM_USE_V1: + self._forward(hidden_states, output, positions, kv_caches) + else: + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, + kv_caches: MinimaxCacheParams) -> torch.Tensor: + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1 and attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens 
= attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + num_actual_tokens = hidden_states.shape[0] + + qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens]) qkv32 = qkv.to(torch.float32) qkvact = torch.nn.functional.silu(qkv32) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: if attn_metadata is not None: - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, LinearAttentionMetadata) kv_cache = self.kv_cache[forward_context.virtual_engine][0] state_indices_tensor = attn_metadata.state_indices_tensor @@ -559,13 +582,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, hidden = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata) - hidden = self.norm._forward(hidden) - gate, _ = self.output_gate(hidden_states) + gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) hidden = F.sigmoid(gate) * hidden hidden = hidden.to(hidden_states.dtype) - hidden, _ = self.out_proj(hidden) - return hidden + output[:num_actual_tokens], _ = self.out_proj(hidden) class MiniMaxText01Attention(nn.Module): @@ -635,8 +656,8 @@ def __init__( ) return - def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - **kwargs) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, + positions: torch.Tensor, **kwargs) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata qkv, _ = self.qkv_proj(hidden_states) @@ -648,8 +669,7 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, else: q, k = attn_metadata.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output + output[:], _ = self.o_proj(attn_output) class MiniMaxText01DecoderLayer(nn.Module): @@ -794,16 +814,15 @@ def forward(self, is_warmup: bool = False, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input - self_attention_output = self.self_attn( + self_attention_output = torch.empty_like(layernorm_output) + self.self_attn( hidden_states=layernorm_output, + output=self_attention_output, positions=positions, kv_caches=kv_caches, - attn_metadata=attn_metadata, ) residual = residual * self.layernorm_attention_alpha @@ -817,8 +836,8 @@ def forward(self, if self.expert_num == 1: hidden_states = self.mlp(layernorm_output) else: - moe_hidden_states = self.block_sparse_moe( - copy.deepcopy(layernorm_output)) + moe_layernorm_output = layernorm_output.clone() + moe_hidden_states = self.block_sparse_moe(moe_layernorm_output) if self.shared_moe: before_moe_dtype = layernorm_output.dtype moe_hidden_fp32 = moe_hidden_states.to(torch.float32) @@ -856,17 +875,15 @@ def shared_moe_coefficient_loader(param: torch.Tensor, return +@support_torch_compile class MiniMaxText01Model(nn.Module): - def __init__( - self, - config: MiniMaxConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - scheduler_config=None, - prefix: str = "", - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
super().__init__() + config: MiniMaxConfig = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + scheduler_config = vllm_config.scheduler_config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -1019,12 +1036,11 @@ def forward(self, attn_metadata = forward_context.attn_metadata if not envs.VLLM_USE_V1 and attn_metadata is None: return None - if "request_ids_to_seq_ids" not in kwargs: - kwargs["request_ids_to_seq_ids"] = {} - if "finished_requests_ids" not in kwargs: - kwargs["finished_requests_ids"] = [] - if not envs.VLLM_USE_V1: + if "request_ids_to_seq_ids" not in kwargs: + kwargs["request_ids_to_seq_ids"] = {} + if "finished_requests_ids" not in kwargs: + kwargs["finished_requests_ids"] = [] ( minimax_cache_tensors, state_indices_tensor, @@ -1096,7 +1112,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config self.config = config self.lora_config = lora_config @@ -1109,12 +1124,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model( - self.config, - quant_config, - cache_config=vllm_config.cache_config, - scheduler_config=vllm_config.scheduler_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiniMaxText01Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -1433,3 +1444,36 @@ def get_mamba_state_shape_from_config( tp_size=parallel_config.tensor_parallel_size, head_dim=hf_config.head_dim, ) + + +def linear_attention( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + print("layer_name: ", layer_name) + self = forward_context.no_compile_layers[layer_name] + self._forward(hidden_states=hidden_states, + output=output, + positions=positions, + kv_caches=None) + + +def linear_attention_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="linear_attention", + op_func=linear_attention, + mutates_args=["output"], + fake_impl=linear_attention_fake, + dispatch_key=current_platform.dispatch_key, +) From c698db3280c6f756ff36086a14931a77acd43d66 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 10 Aug 2025 05:31:52 -0400 Subject: [PATCH 2/5] minor diff reduction Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 2cf52b0e5a44..6def7947c060 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -541,7 +541,6 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, qkvact = torch.nn.functional.silu(qkv32) qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1)) q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1) - if envs.VLLM_USE_V1: if attn_metadata is not None: kv_cache = 
self.kv_cache[forward_context.virtual_engine][0] From bf57695ae7a55f7ef2247dbd03cbec6453980cf3 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 10 Aug 2025 08:37:32 -0400 Subject: [PATCH 3/5] Try to use return instead of mutate Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 6def7947c060..e2d0c11aa654 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -509,21 +509,18 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, slot_id, 32) return hidden - def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, + def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, kv_caches: MinimaxCacheParams) -> torch.Tensor: if not envs.VLLM_USE_V1: - self._forward(hidden_states, output, positions, kv_caches) + return self._forward(hidden_states, positions, kv_caches) else: - torch.ops.vllm.linear_attention( + return torch.ops.vllm.linear_attention( hidden_states, - output, positions, self.prefix, ) - def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, + def _forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, kv_caches: MinimaxCacheParams) -> torch.Tensor: forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata @@ -585,7 +582,8 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) hidden = F.sigmoid(gate) * hidden hidden = hidden.to(hidden_states.dtype) - output[:num_actual_tokens], _ = self.out_proj(hidden) + output, _ = self.out_proj(hidden) + return output[:num_actual_tokens] class MiniMaxText01Attention(nn.Module): @@ -655,8 +653,8 @@ def __init__( ) return - def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, **kwargs) -> None: + def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, + **kwargs) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata qkv, _ = self.qkv_proj(hidden_states) @@ -668,7 +666,8 @@ def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, else: q, k = attn_metadata.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) - output[:], _ = self.o_proj(attn_output) + output, _ = self.o_proj(attn_output) + return output class MiniMaxText01DecoderLayer(nn.Module): @@ -816,10 +815,8 @@ def forward(self, layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input - self_attention_output = torch.empty_like(layernorm_output) - self.self_attn( + self_attention_output = self.self_attn( hidden_states=layernorm_output, - output=self_attention_output, positions=positions, kv_caches=kv_caches, ) @@ -1447,32 +1444,29 @@ def get_mamba_state_shape_from_config( def linear_attention( hidden_states: torch.Tensor, - output: torch.Tensor, positions: torch.Tensor, layer_name: str, -) -> None: +) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() - print("layer_name: ", layer_name) self = forward_context.no_compile_layers[layer_name] - self._forward(hidden_states=hidden_states, - output=output, - positions=positions, - 
kv_caches=None) + output = self._forward(hidden_states=hidden_states, + positions=positions, + kv_caches=None) + return output def linear_attention_fake( hidden_states: torch.Tensor, - output: torch.Tensor, positions: torch.Tensor, layer_name: str, -) -> None: - return +) -> torch.tensor: + return torch.empty_like(hidden_states) direct_register_custom_op( op_name="linear_attention", op_func=linear_attention, - mutates_args=["output"], + mutates_args=[], fake_impl=linear_attention_fake, dispatch_key=current_platform.dispatch_key, ) From c57d39e81fb55f9bf21ec934c047f85677119a40 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 10 Aug 2025 08:39:51 -0400 Subject: [PATCH 4/5] Fix return type Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index e2d0c11aa654..339fbdfd8c03 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -654,7 +654,7 @@ def __init__( return def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, - **kwargs) -> None: + **kwargs) -> torch.Tensor: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata qkv, _ = self.qkv_proj(hidden_states) From 97fdef6908c09153ddfdb6dfc446bde0774f33d3 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 10 Aug 2025 08:41:30 -0400 Subject: [PATCH 5/5] minor cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/models/minimax_text_01.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 339fbdfd8c03..dce6a619bff7 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -578,12 +578,13 @@ def _forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, hidden = self._decode_infer(q, k, v, kv_cache, state_indices_tensor, attn_metadata) + hidden = self.norm._forward(hidden) gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) hidden = F.sigmoid(gate) * hidden hidden = hidden.to(hidden_states.dtype) - output, _ = self.out_proj(hidden) - return output[:num_actual_tokens] + hidden, _ = self.out_proj(hidden) + return hidden[:num_actual_tokens] class MiniMaxText01Attention(nn.Module):
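
The pattern this series relies on — routing a layer's eager forward through a registered custom op so torch.compile only sees a single opaque node, with a fake implementation supplying shapes during tracing — can be sketched in isolation as below. This is a minimal illustration under assumptions, not the PR code itself: the names my_linear_attention, my_linear_attention_fake and the _LAYERS registry are hypothetical stand-ins for vLLM's forward-context layer lookup (forward_context.no_compile_layers).

    # Minimal sketch of the custom-op pattern used above (illustrative names,
    # not the PR code): the eager layer is looked up by its prefix inside the
    # op body, so the compiled graph contains only an opaque call.
    import torch

    from vllm.platforms import current_platform
    from vllm.utils import direct_register_custom_op

    # Stand-in for forward_context.no_compile_layers: maps a layer prefix to
    # the eager module that actually runs linear attention.
    _LAYERS: dict = {}


    def my_linear_attention(hidden_states: torch.Tensor,
                            positions: torch.Tensor,
                            layer_name: str) -> torch.Tensor:
        # Real op body: dispatch to the eager forward of the registered layer.
        layer = _LAYERS[layer_name]
        return layer._forward(hidden_states, positions)


    def my_linear_attention_fake(hidden_states: torch.Tensor,
                                 positions: torch.Tensor,
                                 layer_name: str) -> torch.Tensor:
        # Fake impl: only shapes/dtypes matter during tracing.
        return torch.empty_like(hidden_states)


    direct_register_custom_op(
        op_name="my_linear_attention",
        op_func=my_linear_attention,
        mutates_args=[],  # functional "return instead of mutate" form, as in PATCH 3/5
        fake_impl=my_linear_attention_fake,
        dispatch_key=current_platform.dispatch_key,
    )

Call sites then go through torch.ops.vllm.my_linear_attention(hidden_states, positions, layer_prefix), and adding the op name to the compilation config's splitting_ops (as PATCH 1/5 does for "vllm.linear_attention") makes the compiled graph split around it. Returning a fresh tensor with mutates_args=[] rather than writing into a preallocated output keeps the op purely functional, which is the change PATCH 3/5 makes.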