@@ -509,21 +509,18 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                              slot_id, 32)
         return hidden
 
-    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
+    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
         if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
+            return self._forward(hidden_states, positions, kv_caches)
         else:
-            torch.ops.vllm.linear_attention(
+            return torch.ops.vllm.linear_attention(
                 hidden_states,
-                output,
                 positions,
                 self.prefix,
             )
 
-    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
+    def _forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                  kv_caches: MinimaxCacheParams) -> torch.Tensor:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
@@ -585,7 +582,8 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
         gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
         hidden = F.sigmoid(gate) * hidden
         hidden = hidden.to(hidden_states.dtype)
-        output[:num_actual_tokens], _ = self.out_proj(hidden)
+        output, _ = self.out_proj(hidden)
+        return output[:num_actual_tokens]
 
 
 class MiniMaxText01Attention(nn.Module):
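
The two hunks above switch the linear-attention module from writing into a caller-provided `output` buffer to returning the gated, projected hidden states sliced to the actual token count. A minimal standalone sketch of the two calling conventions, assuming plain PyTorch and illustrative names (`proj`, `n_tokens`); this is not vLLM code:

    import torch

    proj = torch.nn.Linear(8, 8)
    hidden = torch.randn(6, 8)
    n_tokens = 4  # the padded batch can hold more rows than actual tokens

    # Out-parameter style: the caller pre-allocates, the callee writes a slice.
    out_buf = torch.empty(6, 8)
    out_buf[:n_tokens] = proj(hidden[:n_tokens])

    # Return style: the callee returns the sliced result, the caller consumes it.
    out = proj(hidden[:n_tokens])
    assert torch.allclose(out, out_buf[:n_tokens])

The return style lets the op be treated as functional (no mutated arguments), which is why mutates_args is emptied in the final hunk below.
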
@@ -655,8 +653,8 @@ def __init__(
         )
         return
 
-    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor, **kwargs) -> None:
+    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
+                **kwargs) -> None:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
         qkv, _ = self.qkv_proj(hidden_states)
@@ -668,7 +666,8 @@ def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
         else:
             q, k = attn_metadata.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
-        output[:], _ = self.o_proj(attn_output)
+        output, _ = self.o_proj(attn_output)
+        return output
 
 
 class MiniMaxText01DecoderLayer(nn.Module):
@@ -816,10 +815,8 @@ def forward(self,
         layernorm_input = hidden_states
         layernorm_output = self.input_layernorm(layernorm_input)
         residual = layernorm_output if self.postnorm else layernorm_input
-        self_attention_output = torch.empty_like(layernorm_output)
-        self.self_attn(
+        self_attention_output = self.self_attn(
             hidden_states=layernorm_output,
-            output=self_attention_output,
             positions=positions,
             kv_caches=kv_caches,
         )
@@ -1447,32 +1444,29 @@ def get_mamba_state_shape_from_config(
 
 def linear_attention(
         hidden_states: torch.Tensor,
-        output: torch.Tensor,
         positions: torch.Tensor,
         layer_name: str,
-) -> None:
+) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    print("layer_name: ", layer_name)
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states,
-                  output=output,
-                  positions=positions,
-                  kv_caches=None)
+    output = self._forward(hidden_states=hidden_states,
+                           positions=positions,
+                           kv_caches=None)
+    return output
 
 
 def linear_attention_fake(
         hidden_states: torch.Tensor,
-        output: torch.Tensor,
         positions: torch.Tensor,
         layer_name: str,
-) -> None:
-    return
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
 
 
 direct_register_custom_op(
     op_name="linear_attention",
     op_func=linear_attention,
-    mutates_args=["output"],
+    mutates_args=[],
     fake_impl=linear_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
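
Because the custom op now returns its result instead of mutating an `output` argument, `mutates_args` is empty and the fake implementation must produce a tensor with the correct shape and dtype so the compiler can trace the graph. A rough sketch of the same registration pattern using plain PyTorch 2.4+ `torch.library.custom_op` rather than vLLM's `direct_register_custom_op` helper; the op name `mylib::linear_attention` and the identity body are illustrative only:

    import torch

    @torch.library.custom_op("mylib::linear_attention", mutates_args=())
    def linear_attention(hidden_states: torch.Tensor,
                         positions: torch.Tensor,
                         layer_name: str) -> torch.Tensor:
        # A real kernel would run here; clone() stands in for the layer lookup
        # and must not alias the input.
        return hidden_states.clone()

    @linear_attention.register_fake
    def _(hidden_states, positions, layer_name):
        # Shape/dtype-only stand-in used during tracing, mirroring
        # linear_attention_fake above.
        return torch.empty_like(hidden_states)

    out = torch.ops.mylib.linear_attention(
        torch.randn(4, 8), torch.arange(4), "layer0")
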