diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 8a78d811b9a2..f6ffe4505331 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -425,4 +425,5 @@ def set_splitting_ops_for_v1(self):
                 "vllm.unified_attention",
                 "vllm.unified_attention_with_output",
                 "vllm.mamba_mixer2",
+                "vllm.linear_attention",
             ]
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 3d14a6ad5c3a..dce6a619bff7 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Inference-only MiniMaxText01 model."""
-import copy
 import math
 from collections.abc import Iterable
 from typing import Optional, Union
@@ -16,12 +15,13 @@
 
 from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -44,7 +44,9 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 
 from .interfaces import HasInnerState, IsHybrid
@@ -508,19 +510,36 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
         return hidden
 
     def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+                kv_caches: MinimaxCacheParams) -> torch.Tensor:
+        if not envs.VLLM_USE_V1:
+            return self._forward(hidden_states, positions, kv_caches)
+        else:
+            return torch.ops.vllm.linear_attention(
+                hidden_states,
+                positions,
+                self.prefix,
+            )
+
+    def _forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
+                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        if envs.VLLM_USE_V1 and attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)
+            attn_metadata = attn_metadata[self.prefix]
+            assert isinstance(attn_metadata, LinearAttentionMetadata)
+            num_actual_tokens = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            num_actual_tokens = hidden_states.shape[0]
+
+        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
         qkv32 = qkv.to(torch.float32)
         qkvact = torch.nn.functional.silu(qkv32)
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
         if envs.VLLM_USE_V1:
             if attn_metadata is not None:
-                assert isinstance(attn_metadata, dict)
-                attn_metadata = attn_metadata[self.prefix]
-                assert isinstance(attn_metadata, LinearAttentionMetadata)
                 kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                 state_indices_tensor = attn_metadata.state_indices_tensor
@@ -561,11 +580,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                                         attn_metadata)
         hidden = self.norm._forward(hidden)
-        gate, _ = self.output_gate(hidden_states)
+        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
         hidden = F.sigmoid(gate) * hidden
         hidden = hidden.to(hidden_states.dtype)
         hidden, _ = self.out_proj(hidden)
-        return hidden
+        return hidden[:num_actual_tokens]
 
 
 class MiniMaxText01Attention(nn.Module):
@@ -794,8 +813,6 @@ def forward(self,
                 is_warmup: bool = False,
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
         layernorm_input = hidden_states
         layernorm_output = self.input_layernorm(layernorm_input)
         residual = layernorm_output if self.postnorm else layernorm_input
@@ -803,7 +820,6 @@ def forward(self,
             hidden_states=layernorm_output,
             positions=positions,
             kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
         )
 
         residual = residual * self.layernorm_attention_alpha
@@ -817,8 +833,8 @@ def forward(self,
         if self.expert_num == 1:
             hidden_states = self.mlp(layernorm_output)
         else:
-            moe_hidden_states = self.block_sparse_moe(
-                copy.deepcopy(layernorm_output))
+            moe_layernorm_output = layernorm_output.clone()
+            moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
             if self.shared_moe:
                 before_moe_dtype = layernorm_output.dtype
                 moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
@@ -856,17 +872,15 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
     return
 
 
+@support_torch_compile
 class MiniMaxText01Model(nn.Module):
 
-    def __init__(
-        self,
-        config: MiniMaxConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        scheduler_config=None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config: MiniMaxConfig = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        scheduler_config = vllm_config.scheduler_config
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -1019,12 +1033,11 @@ def forward(self,
         attn_metadata = forward_context.attn_metadata
         if not envs.VLLM_USE_V1 and attn_metadata is None:
             return None
-        if "request_ids_to_seq_ids" not in kwargs:
-            kwargs["request_ids_to_seq_ids"] = {}
-        if "finished_requests_ids" not in kwargs:
-            kwargs["finished_requests_ids"] = []
-
         if not envs.VLLM_USE_V1:
+            if "request_ids_to_seq_ids" not in kwargs:
+                kwargs["request_ids_to_seq_ids"] = {}
+            if "finished_requests_ids" not in kwargs:
+                kwargs["finished_requests_ids"] = []
             (
                 minimax_cache_tensors,
                 state_indices_tensor,
@@ -1096,7 +1109,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
@@ -1109,12 +1121,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.unpadded_vocab_size = self.config.vocab_size
         if hasattr(vllm_config.model_config, "max_model_len"):
             self.config.max_model_len = vllm_config.model_config.max_model_len
-        self.model = MiniMaxText01Model(
-            self.config,
-            quant_config,
-            cache_config=vllm_config.cache_config,
-            scheduler_config=vllm_config.scheduler_config,
-            prefix=maybe_prefix(prefix, "model"))
+        self.model = MiniMaxText01Model(vllm_config=vllm_config,
+                                        prefix=maybe_prefix(prefix, "model"))
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,
@@ -1433,3 +1441,33 @@ def get_mamba_state_shape_from_config(
             tp_size=parallel_config.tensor_parallel_size,
             head_dim=hf_config.head_dim,
         )
+
+
+def linear_attention(
+    hidden_states: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> torch.Tensor:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    output = self._forward(hidden_states=hidden_states,
+                           positions=positions,
+                           kv_caches=None)
+    return output
+
+
+def linear_attention_fake(
+    hidden_states: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="linear_attention",
+    op_func=linear_attention,
+    mutates_args=[],
+    fake_impl=linear_attention_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
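Note on the custom-op indirection used above: the new `linear_attention`/`linear_attention_fake` pair is registered through vLLM's `direct_register_custom_op`, so the linear-attention layer appears to `torch.compile` as a single opaque op that the compiler can split the graph around (hence the new `"vllm.linear_attention"` entry in `set_splitting_ops_for_v1`). The sketch below is a hypothetical, self-contained illustration of the same pattern using `torch.library.custom_op` directly (assumes PyTorch >= 2.4), so it runs without vLLM; the names `toy::linear_attention`, `ToyLinearAttention`, and `_LAYER_REGISTRY` are illustrative stand-ins, not vLLM APIs.

import torch

# Stand-in for forward_context.no_compile_layers: maps a layer prefix to its module.
_LAYER_REGISTRY: dict[str, torch.nn.Module] = {}


@torch.library.custom_op("toy::linear_attention", mutates_args=())
def toy_linear_attention(hidden_states: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Real implementation: look the layer up by name and run its eager _forward,
    # keeping the stateful attention logic outside the compiled graph.
    return _LAYER_REGISTRY[layer_name]._forward(hidden_states)


@toy_linear_attention.register_fake
def _(hidden_states: torch.Tensor, layer_name: str) -> torch.Tensor:
    # Shape/dtype-only "fake" impl so the compiler can trace through the op,
    # analogous to linear_attention_fake in the patch.
    return torch.empty_like(hidden_states)


class ToyLinearAttention(torch.nn.Module):
    def __init__(self, dim: int, prefix: str):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.prefix = prefix
        _LAYER_REGISTRY[prefix] = self

    def _forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Eager-mode computation, invoked from inside the custom op.
        return self.proj(hidden_states)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The compiled graph only sees the opaque custom op, identified by prefix.
        return torch.ops.toy.linear_attention(hidden_states, self.prefix)


if __name__ == "__main__":
    layer = ToyLinearAttention(dim=16, prefix="layers.0.linear_attn")
    with torch.inference_mode():
        out = torch.compile(layer)(torch.randn(4, 16))
    print(out.shape)  # torch.Size([4, 16])

The design choice mirrored here is that the layer's `forward` only dispatches by name, while the heavy, metadata-dependent work stays in `_forward`; this is why the patch moves the V1 metadata handling out of `forward` and into `_forward`.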