# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only MiniMaxText01 model."""
-import copy
import math
from collections.abc import Iterable
from typing import Optional, Union
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
+from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid
@@ -507,20 +509,41 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                              slot_id, 32)
        return hidden

-    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: MinimaxCacheParams) -> torch.Tensor:
+        if not envs.VLLM_USE_V1:
+            self._forward(hidden_states, output, positions, kv_caches)
+        else:
+            torch.ops.vllm.linear_attention(
+                hidden_states,
+                output,
+                positions,
+                self.prefix,
+            )
+
+    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                 positions: torch.Tensor,
+                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
+        forward_context = get_forward_context()
+        attn_metadata: AttentionMetadata = forward_context.attn_metadata
+        if envs.VLLM_USE_V1 and attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)
+            attn_metadata = attn_metadata[self.prefix]
+            assert isinstance(attn_metadata, LinearAttentionMetadata)
+            num_actual_tokens = attn_metadata.num_prefill_tokens + \
+                attn_metadata.num_decode_tokens
+        else:
+            num_actual_tokens = hidden_states.shape[0]
+
+        qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
        qkv32 = qkv.to(torch.float32)
        qkvact = torch.nn.functional.silu(qkv32)
        qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
        q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
+
        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
-                assert isinstance(attn_metadata, dict)
-                attn_metadata = attn_metadata[self.prefix]
-                assert isinstance(attn_metadata, LinearAttentionMetadata)
                kv_cache = self.kv_cache[forward_context.virtual_engine][0]
                state_indices_tensor = attn_metadata.state_indices_tensor
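
The hunk above, together with the one that follows, moves the linear-attention layer's forward to an out-parameter convention: the caller preallocates `output`, only the first `num_prefill_tokens + num_decode_tokens` rows are treated as real tokens, and the result is written back in place instead of being returned. A minimal sketch of that convention (the name `toy_attn` and the dummy computation are illustrative only, not part of this diff):

    import torch

    def toy_attn(hidden_states: torch.Tensor, output: torch.Tensor,
                 num_actual_tokens: int) -> None:
        # Compute only on the rows that hold real tokens, then write in place.
        valid = hidden_states[:num_actual_tokens]
        output[:num_actual_tokens] = valid * 2.0  # stand-in for proj/attention

    hidden = torch.randn(8, 16)      # e.g. 8 padded rows, only 5 real tokens
    out = torch.empty_like(hidden)   # caller preallocates the result buffer
    toy_attn(hidden, out, num_actual_tokens=5)
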
@@ -559,13 +582,11 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                hidden = self._decode_infer(q, k, v, kv_cache,
                                            state_indices_tensor,
                                            attn_metadata)
-
        hidden = self.norm._forward(hidden)
-        gate, _ = self.output_gate(hidden_states)
+        gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
        hidden = F.sigmoid(gate) * hidden
        hidden = hidden.to(hidden_states.dtype)
-        hidden, _ = self.out_proj(hidden)
-        return hidden
+        output[:num_actual_tokens], _ = self.out_proj(hidden)


class MiniMaxText01Attention(nn.Module):
@@ -635,8 +656,8 @@ def __init__(
        )
        return

-    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
-                **kwargs) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
+                positions: torch.Tensor, **kwargs) -> None:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        qkv, _ = self.qkv_proj(hidden_states)
@@ -648,8 +669,7 @@ def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
        else:
            q, k = attn_metadata.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
-        output, _ = self.o_proj(attn_output)
-        return output
+        output[:], _ = self.o_proj(attn_output)


class MiniMaxText01DecoderLayer(nn.Module):
@@ -794,16 +814,15 @@ def forward(self,
                is_warmup: bool = False,
                **kwargs) -> tuple[torch.Tensor, torch.Tensor]:

-        forward_context = get_forward_context()
-        attn_metadata = forward_context.attn_metadata
        layernorm_input = hidden_states
        layernorm_output = self.input_layernorm(layernorm_input)
        residual = layernorm_output if self.postnorm else layernorm_input
-        self_attention_output = self.self_attn(
+        self_attention_output = torch.empty_like(layernorm_output)
+        self.self_attn(
            hidden_states=layernorm_output,
+            output=self_attention_output,
            positions=positions,
            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
        )

        residual = residual * self.layernorm_attention_alpha
@@ -817,8 +836,8 @@ def forward(self,
        if self.expert_num == 1:
            hidden_states = self.mlp(layernorm_output)
        else:
-            moe_hidden_states = self.block_sparse_moe(
-                copy.deepcopy(layernorm_output))
+            moe_layernorm_output = layernorm_output.clone()
+            moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
            if self.shared_moe:
                before_moe_dtype = layernorm_output.dtype
                moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
@@ -856,17 +875,15 @@ def shared_moe_coefficient_loader(param: torch.Tensor,
        return


+@support_torch_compile
class MiniMaxText01Model(nn.Module):

-    def __init__(
-        self,
-        config: MiniMaxConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        scheduler_config=None,
-        prefix: str = "",
-    ) -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
+        config: MiniMaxConfig = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        scheduler_config = vllm_config.scheduler_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
@@ -1019,12 +1036,11 @@ def forward(self,
        attn_metadata = forward_context.attn_metadata
        if not envs.VLLM_USE_V1 and attn_metadata is None:
            return None
-        if "request_ids_to_seq_ids" not in kwargs:
-            kwargs["request_ids_to_seq_ids"] = {}
-        if "finished_requests_ids" not in kwargs:
-            kwargs["finished_requests_ids"] = []
-
        if not envs.VLLM_USE_V1:
+            if "request_ids_to_seq_ids" not in kwargs:
+                kwargs["request_ids_to_seq_ids"] = {}
+            if "finished_requests_ids" not in kwargs:
+                kwargs["finished_requests_ids"] = []
            (
                minimax_cache_tensors,
                state_indices_tensor,
@@ -1096,7 +1112,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

        super().__init__()
        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
@@ -1109,12 +1124,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        self.unpadded_vocab_size = self.config.vocab_size
        if hasattr(vllm_config.model_config, "max_model_len"):
            self.config.max_model_len = vllm_config.model_config.max_model_len
-        self.model = MiniMaxText01Model(
-            self.config,
-            quant_config,
-            cache_config=vllm_config.cache_config,
-            scheduler_config=vllm_config.scheduler_config,
-            prefix=maybe_prefix(prefix, "model"))
+        self.model = MiniMaxText01Model(vllm_config=vllm_config,
+                                        prefix=maybe_prefix(prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
@@ -1433,3 +1444,35 @@ def get_mamba_state_shape_from_config(
            tp_size=parallel_config.tensor_parallel_size,
            head_dim=hf_config.head_dim,
        )
+
+
+def linear_attention(
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        positions: torch.Tensor,
+        layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self._forward(hidden_states=hidden_states,
+                  output=output,
+                  positions=positions,
+                  kv_caches=None)
+
+
+def linear_attention_fake(
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        positions: torch.Tensor,
+        layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="linear_attention",
+    op_func=linear_attention,
+    mutates_args=["output"],
+    fake_impl=linear_attention_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
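
The registration block above exposes the layer's `_forward` as the `torch.ops.vllm.linear_attention` op that the V1 path calls from `forward`: `mutates_args=["output"]` tells the compiler the op writes into the caller's buffer, and the fake implementation gives tracing a shape-only no-op. A self-contained sketch of the same pattern using the public `torch.library` API (the `demo::` namespace and dummy kernel are made up for illustration; vLLM's `direct_register_custom_op` builds on the same `torch.library` machinery):

    import torch
    from torch.library import custom_op

    @custom_op("demo::linear_attention", mutates_args=("output",))
    def demo_linear_attention(hidden_states: torch.Tensor,
                              output: torch.Tensor,
                              layer_name: str) -> None:
        # Eager implementation: the op mutates `output` and returns nothing.
        output.copy_(hidden_states * 2.0)

    @demo_linear_attention.register_fake
    def _(hidden_states: torch.Tensor, output: torch.Tensor,
          layer_name: str) -> None:
        # Fake (meta) impl used during tracing/compilation: no computation,
        # only the fact that `output` is mutated in place matters.
        return

    hidden = torch.randn(4, 8)
    out = torch.empty_like(hidden)   # caller preallocates, as in the model
    torch.ops.demo.linear_attention(hidden, out, "layers.0.linear_attn")
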