[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models #22589

Open · wants to merge 5 commits into main
Changes from 2 commits

1 change: 1 addition & 0 deletions vllm/config/compilation.py
@@ -425,4 +425,5 @@ def set_splitting_ops_for_v1(self):
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
"vllm.linear_attention",
]
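
The op name added here has to match what the model file registers below via direct_register_custom_op(op_name="linear_attention", ...), which surfaces as torch.ops.vllm.linear_attention. For reference only, a minimal sketch of spelling the resulting split points out from user code; the model id, and the assumption that splitting_ops can be passed through compilation_config, are not taken from this PR:

    # Sketch only (not part of this diff): spell out the split points explicitly.
    # Assumes CompilationConfig.splitting_ops is user-settable and that LLM
    # accepts a compilation_config argument, as in recent vLLM releases.
    from vllm import LLM
    from vllm.config import CompilationConfig

    llm = LLM(
        model="MiniMaxAI/MiniMax-Text-01",  # illustrative model id
        compilation_config=CompilationConfig(
            # Piecewise CUDA graphs split the compiled graph at these custom
            # ops, so the attention / mamba / linear-attention kernels run
            # outside the captured graph pieces.
            splitting_ops=[
                "vllm.unified_attention",
                "vllm.unified_attention_with_output",
                "vllm.mamba_mixer2",
                "vllm.linear_attention",
            ],
        ),
    )
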
131 changes: 87 additions & 44 deletions vllm/model_executor/models/minimax_text_01.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only MiniMaxText01 model."""
import copy
import math
from collections.abc import Iterable
from typing import Optional, Union
@@ -16,12 +15,13 @@

from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -44,7 +44,9 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid
@@ -507,20 +509,40 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
slot_id, 32)
return hidden

def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> torch.Tensor:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
torch.ops.vllm.linear_attention(
hidden_states,
output,
positions,
self.prefix,
)

def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> torch.Tensor:
Contributor comment (severity: high):

The type hints on the forward and _forward methods of MiniMaxText01LinearAttention have two issues that should be corrected for clarity and correctness:

  1. The return type for both forward (L514) and _forward (L527) is annotated as torch.Tensor, but neither function returns a value. They should be annotated with -> None.
  2. The kv_caches parameter in _forward (L527) is annotated as MinimaxCacheParams, but it's called with None from the linear_attention custom op (L1460). It should be Optional[MinimaxCacheParams].
Suggested change:

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: MinimaxCacheParams) -> None:
        if not envs.VLLM_USE_V1:
            self._forward(hidden_states, output, positions, kv_caches)
        else:
            torch.ops.vllm.linear_attention(
                hidden_states,
                output,
                positions,
                self.prefix,
            )

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: Optional[MinimaxCacheParams]) -> None:
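
Annotating both methods with -> None also matches the new in-place contract: _forward writes its result into output[:num_actual_tokens] via out_proj instead of returning a tensor.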

forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
num_actual_tokens = hidden_states.shape[0]

qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
qkv32 = qkv.to(torch.float32)
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor

@@ -559,13 +581,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)

hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
hidden = hidden.to(hidden_states.dtype)
hidden, _ = self.out_proj(hidden)
return hidden
output[:num_actual_tokens], _ = self.out_proj(hidden)


class MiniMaxText01Attention(nn.Module):
@@ -635,8 +655,8 @@ def __init__(
)
return

def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
**kwargs) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor, **kwargs) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
qkv, _ = self.qkv_proj(hidden_states)
@@ -648,8 +668,7 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
else:
q, k = attn_metadata.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
output[:], _ = self.o_proj(attn_output)


class MiniMaxText01DecoderLayer(nn.Module):
@@ -794,16 +813,15 @@ def forward(self,
is_warmup: bool = False,
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:

forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
layernorm_input = hidden_states
layernorm_output = self.input_layernorm(layernorm_input)
residual = layernorm_output if self.postnorm else layernorm_input
self_attention_output = self.self_attn(
self_attention_output = torch.empty_like(layernorm_output)
self.self_attn(
hidden_states=layernorm_output,
output=self_attention_output,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)

residual = residual * self.layernorm_attention_alpha
@@ -817,8 +835,8 @@ def forward(self,
if self.expert_num == 1:
hidden_states = self.mlp(layernorm_output)
else:
moe_hidden_states = self.block_sparse_moe(
copy.deepcopy(layernorm_output))
moe_layernorm_output = layernorm_output.clone()
moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
if self.shared_moe:
before_moe_dtype = layernorm_output.dtype
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
@@ -856,17 +874,15 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
return


@support_torch_compile
class MiniMaxText01Model(nn.Module):

def __init__(
self,
config: MiniMaxConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
scheduler_config=None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: MiniMaxConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
scheduler_config = vllm_config.scheduler_config

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@@ -1019,12 +1035,11 @@ def forward(self,
attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []

if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
@@ -1096,7 +1111,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
@@ -1109,12 +1123,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.unpadded_vocab_size = self.config.vocab_size
if hasattr(vllm_config.model_config, "max_model_len"):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
self.config,
quant_config,
cache_config=vllm_config.cache_config,
scheduler_config=vllm_config.scheduler_config,
prefix=maybe_prefix(prefix, "model"))
self.model = MiniMaxText01Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
@@ -1433,3 +1443,36 @@ def get_mamba_state_shape_from_config(
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)


def linear_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
print("layer_name: ", layer_name)
Contributor comment (severity: high):

This print statement appears to be a leftover from debugging. It should be removed to avoid polluting logs in production (a minimal sketch of the fix follows after the diff).

self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions,
kv_caches=None)


def linear_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
return


direct_register_custom_op(
op_name="linear_attention",
op_func=linear_attention,
mutates_args=["output"],
fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
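
Following up on the review comment about the debug print: a minimal sketch (not a suggested change actually posted on the PR) of the op body with the print dropped; everything else is unchanged from the diff above.

    def linear_attention(
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        positions: torch.Tensor,
        layer_name: str,
    ) -> None:
        # Look up the layer registered under this prefix and run its eager
        # implementation; the debug print is simply removed.
        forward_context: ForwardContext = get_forward_context()
        self = forward_context.no_compile_layers[layer_name]
        self._forward(hidden_states=hidden_states,
                      output=output,
                      positions=positions,
                      kv_caches=None)

The fake impl stays a no-op and mutates_args=["output"] stays as registered, since the op's only observable effect is the in-place write into output.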