
Commit 9ad6271

Enable compile for minimax
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 42172ad commit 9ad6271

2 files changed: 89 additions & 44 deletions

vllm/config/compilation.py

Lines changed: 1 addition & 0 deletions
@@ -425,4 +425,5 @@ def set_splitting_ops_for_v1(self):
             "vllm.unified_attention",
             "vllm.unified_attention_with_output",
             "vllm.mamba_mixer2",
+            "vllm.linear_attention",
         ]
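Adding "vllm.linear_attention" to the splitting list makes the new custom op a boundary for piecewise compilation: torch.compile captures the graph around the op and leaves the op itself to run eagerly. The snippet below is a minimal, self-contained sketch of that pattern using plain torch.library (assumes PyTorch >= 2.4) rather than vLLM's direct_register_custom_op helper; the "myops" namespace, the toy op body, and the doubled output are made up for illustration.

import torch

# Register a custom op whose body torch.compile will not trace into;
# the fake impl below supplies shape/dtype information during graph capture.
@torch.library.custom_op("myops::linear_attention", mutates_args=("output",))
def toy_linear_attention(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Eager-only body; a real kernel (or untraceable Python) would go here.
    output.copy_(hidden_states * 2)

@toy_linear_attention.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Metadata-only implementation used while tracing; nothing to return
    # because the op mutates `output` in place.
    return None

@torch.compile
def layer(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # The compiled graph contains a single opaque call to the custom op.
    torch.ops.myops.linear_attention(x, out)
    return out + 1

print(layer(torch.randn(4, 8)))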

vllm/model_executor/models/minimax_text_01.py

Lines changed: 88 additions & 44 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Inference-only MiniMaxText01 model."""
-import copy
 import math
 from collections.abc import Iterable
 from typing import Optional, Union
@@ -16,12 +15,13 @@
 
 from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -44,7 +44,9 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
 
 from .interfaces import HasInnerState, IsHybrid
@@ -507,20 +509,41 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                             slot_id, 32)
         return hidden
 
-    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: MinimaxCacheParams) -> torch.Tensor:
+        if not envs.VLLM_USE_V1:
+            self._forward(hidden_states, output, positions, kv_caches)
+        else:
+            torch.ops.vllm.linear_attention(
+                hidden_states,
+                output,
+                positions,
+                self.prefix,
+            )
+
+    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                 positions: torch.Tensor,
+                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        if envs.VLLM_USE_V1 and attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)
+            attn_metadata = attn_metadata[self.prefix]
+            assert isinstance(attn_metadata, LinearAttentionMetadata)
+            num_actual_tokens = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            num_actual_tokens = hidden_states.shape[0]
+
+        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
         qkv32 = qkv.to(torch.float32)
         qkvact = torch.nn.functional.silu(qkv32)
         qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
         q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
+
         if envs.VLLM_USE_V1:
             if attn_metadata is not None:
-                assert isinstance(attn_metadata, dict)
-                attn_metadata = attn_metadata[self.prefix]
-                assert isinstance(attn_metadata, LinearAttentionMetadata)
                 kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                 state_indices_tensor = attn_metadata.state_indices_tensor
 
@@ -559,13 +582,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
             hidden = self._decode_infer(q, k, v, kv_cache,
                                         state_indices_tensor,
                                         attn_metadata)
-
         hidden = self.norm._forward(hidden)
-        gate, _ = self.output_gate(hidden_states)
+        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
         hidden = F.sigmoid(gate) * hidden
         hidden = hidden.to(hidden_states.dtype)
-        hidden, _ = self.out_proj(hidden)
-        return hidden
+        output[:num_actual_tokens], _ = self.out_proj(hidden)
 
 
 class MiniMaxText01Attention(nn.Module):
@@ -635,8 +656,8 @@ def __init__(
         )
         return
 
-    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                **kwargs) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor, **kwargs) -> None:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
         qkv, _ = self.qkv_proj(hidden_states)
@@ -648,8 +669,7 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
         else:
             q, k = attn_metadata.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
-        output, _ = self.o_proj(attn_output)
-        return output
+        output[:], _ = self.o_proj(attn_output)
 
 
 class MiniMaxText01DecoderLayer(nn.Module):
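Both attention variants above switch from returning a fresh tensor to filling a caller-allocated buffer (output[:num_actual_tokens] for linear attention, output[:] for softmax attention), which is what lets the custom op declare the mutation via mutates_args=["output"]. A trivial, purely illustrative contrast of the two conventions (toy functions, not vLLM code):

import torch

def attn_old(hidden_states: torch.Tensor) -> torch.Tensor:
    # old convention: allocate and return a new tensor
    return hidden_states * 2

def attn_new(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # new convention: write into the buffer the caller provides
    output[:] = hidden_states * 2

x = torch.randn(4, 8)
out = torch.empty_like(x)
attn_new(x, out)
assert torch.allclose(out, attn_old(x))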
@@ -794,16 +814,15 @@ def forward(self,
                 is_warmup: bool = False,
                 **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
 
-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
         layernorm_input = hidden_states
         layernorm_output = self.input_layernorm(layernorm_input)
         residual = layernorm_output if self.postnorm else layernorm_input
-        self_attention_output = self.self_attn(
+        self_attention_output = torch.empty_like(layernorm_output)
+        self.self_attn(
             hidden_states=layernorm_output,
+            output=self_attention_output,
             positions=positions,
             kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
         )
 
         residual = residual * self.layernorm_attention_alpha
@@ -817,8 +836,8 @@ def forward(self,
         if self.expert_num == 1:
             hidden_states = self.mlp(layernorm_output)
         else:
-            moe_hidden_states = self.block_sparse_moe(
-                copy.deepcopy(layernorm_output))
+            moe_layernorm_output = layernorm_output.clone()
+            moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
         if self.shared_moe:
             before_moe_dtype = layernorm_output.dtype
             moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
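The MoE branch swaps copy.deepcopy for Tensor.clone(): clone() is a regular torch op that graph capture can trace, while copy.deepcopy is an opaque Python call that torch.compile handles poorly (and which detaches the result from the autograd graph). A quick standalone illustration, unrelated to the model code:

import copy
import torch

x = torch.randn(2, 3, requires_grad=True)

y_clone = x.clone()        # regular torch op: traceable, stays on the autograd graph
y_deep = copy.deepcopy(x)  # opaque Python call: new leaf tensor with no grad_fn

print(y_clone.grad_fn is not None)  # True
print(y_deep.grad_fn is not None)   # False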
@@ -856,17 +875,15 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
     return
 
 
+@support_torch_compile
 class MiniMaxText01Model(nn.Module):
 
-    def __init__(
-        self,
-        config: MiniMaxConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        scheduler_config=None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config: MiniMaxConfig = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        scheduler_config = vllm_config.scheduler_config
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
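MiniMaxText01Model now uses the standard keyword-only vllm_config constructor and is decorated with @support_torch_compile, the decorator that lets vLLM compile the model's forward pass. The toy decorator below is only a rough sketch of that idea under simplifying assumptions; it is not vLLM's implementation, which integrates with VllmConfig and the piecewise compilation backend.

import torch

def toy_support_torch_compile(cls):
    """Wrap cls.forward with torch.compile the first time it runs."""
    original_forward = cls.forward

    def compiled_forward(self, *args, **kwargs):
        if not hasattr(self, "_compiled_forward"):
            self._compiled_forward = torch.compile(original_forward)
        return self._compiled_forward(self, *args, **kwargs)

    cls.forward = compiled_forward
    return cls

@toy_support_torch_compile
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

print(ToyModel()(torch.randn(2, 8)).shape)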
@@ -1019,12 +1036,11 @@ def forward(self,
         attn_metadata = forward_context.attn_metadata
         if not envs.VLLM_USE_V1 and attn_metadata is None:
             return None
-        if "request_ids_to_seq_ids" not in kwargs:
-            kwargs["request_ids_to_seq_ids"] = {}
-        if "finished_requests_ids" not in kwargs:
-            kwargs["finished_requests_ids"] = []
-
         if not envs.VLLM_USE_V1:
+            if "request_ids_to_seq_ids" not in kwargs:
+                kwargs["request_ids_to_seq_ids"] = {}
+            if "finished_requests_ids" not in kwargs:
+                kwargs["finished_requests_ids"] = []
             (
                 minimax_cache_tensors,
                 state_indices_tensor,
@@ -1096,7 +1112,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         super().__init__()
         config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
@@ -1109,12 +1124,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.unpadded_vocab_size = self.config.vocab_size
         if hasattr(vllm_config.model_config, "max_model_len"):
             self.config.max_model_len = vllm_config.model_config.max_model_len
-        self.model = MiniMaxText01Model(
-            self.config,
-            quant_config,
-            cache_config=vllm_config.cache_config,
-            scheduler_config=vllm_config.scheduler_config,
-            prefix=maybe_prefix(prefix, "model"))
+        self.model = MiniMaxText01Model(vllm_config=vllm_config,
+                                        prefix=maybe_prefix(prefix, "model"))
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(
                 self.unpadded_vocab_size,
@@ -1433,3 +1444,36 @@ def get_mamba_state_shape_from_config(
         tp_size=parallel_config.tensor_parallel_size,
         head_dim=hf_config.head_dim,
     )
+
+
+def linear_attention(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    print("layer_name: ", layer_name)
+    self = forward_context.no_compile_layers[layer_name]
+    self._forward(hidden_states=hidden_states,
+                  output=output,
+                  positions=positions,
+                  kv_caches=None)
+
+
+def linear_attention_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    positions: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="linear_attention",
+    op_func=linear_attention,
+    mutates_args=["output"],
+    fake_impl=linear_attention_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
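Under V1 the compiled graph only ever sees torch.ops.vllm.linear_attention(hidden_states, output, positions, layer_name); the op body looks the layer up by its prefix in forward_context.no_compile_layers and runs its eager _forward. The standalone sketch below illustrates that lookup-by-name dispatch pattern with toy names; it is not vLLM's API.

import torch

# Stand-in for forward_context.no_compile_layers: prefix -> layer instance.
LAYER_REGISTRY: dict[str, "ToyLinearAttention"] = {}

class ToyLinearAttention(torch.nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix
        LAYER_REGISTRY[prefix] = self

    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Placeholder for the real linear-attention math.
        output[:] = hidden_states * 2

def toy_linear_attention(hidden_states: torch.Tensor, output: torch.Tensor,
                         layer_name: str) -> None:
    # The op receives only tensors plus a string, so the captured graph stays
    # free of module references; dispatch to the layer happens at run time.
    LAYER_REGISTRY[layer_name]._forward(hidden_states, output)

layer = ToyLinearAttention("model.layers.0.linear_attn")
x = torch.randn(4, 8)
out = torch.empty_like(x)
toy_linear_attention(x, out, "model.layers.0.linear_attn")
print(torch.allclose(out, x * 2))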

0 commit comments