
Commit bf57695

Try to use return instead of mutate
Signed-off-by: Thomas Parnell <[email protected]>
1 parent c698db3 commit bf57695

File tree

1 file changed: +19 −25 lines

vllm/model_executor/models/minimax_text_01.py

Lines changed: 19 additions & 25 deletions
@@ -509,21 +509,18 @@ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
                                               slot_id, 32)
         return hidden

-    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor,
+    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                 kv_caches: MinimaxCacheParams) -> torch.Tensor:
         if not envs.VLLM_USE_V1:
-            self._forward(hidden_states, output, positions, kv_caches)
+            return self._forward(hidden_states, positions, kv_caches)
         else:
-            torch.ops.vllm.linear_attention(
+            return torch.ops.vllm.linear_attention(
                 hidden_states,
-                output,
                 positions,
                 self.prefix,
             )

-    def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                 positions: torch.Tensor,
+    def _forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
                  kv_caches: MinimaxCacheParams) -> torch.Tensor:
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
@@ -585,7 +582,8 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
         gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
         hidden = F.sigmoid(gate) * hidden
         hidden = hidden.to(hidden_states.dtype)
-        output[:num_actual_tokens], _ = self.out_proj(hidden)
+        output, _ = self.out_proj(hidden)
+        return output[:num_actual_tokens]


 class MiniMaxText01Attention(nn.Module):
@@ -655,8 +653,8 @@ def __init__(
         )
         return

-    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
-                positions: torch.Tensor, **kwargs) -> None:
+    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
+                **kwargs) -> None:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
         qkv, _ = self.qkv_proj(hidden_states)
@@ -668,7 +666,8 @@ def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
         else:
             q, k = attn_metadata.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
-        output[:], _ = self.o_proj(attn_output)
+        output, _ = self.o_proj(attn_output)
+        return output


 class MiniMaxText01DecoderLayer(nn.Module):
@@ -816,10 +815,8 @@ def forward(self,
         layernorm_input = hidden_states
         layernorm_output = self.input_layernorm(layernorm_input)
         residual = layernorm_output if self.postnorm else layernorm_input
-        self_attention_output = torch.empty_like(layernorm_output)
-        self.self_attn(
+        self_attention_output = self.self_attn(
             hidden_states=layernorm_output,
-            output=self_attention_output,
             positions=positions,
             kv_caches=kv_caches,
         )
@@ -1447,32 +1444,29 @@ def get_mamba_state_shape_from_config(

 def linear_attention(
     hidden_states: torch.Tensor,
-    output: torch.Tensor,
     positions: torch.Tensor,
     layer_name: str,
-) -> None:
+) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
-    print("layer_name: ", layer_name)
     self = forward_context.no_compile_layers[layer_name]
-    self._forward(hidden_states=hidden_states,
-                  output=output,
-                  positions=positions,
-                  kv_caches=None)
+    output = self._forward(hidden_states=hidden_states,
+                           positions=positions,
+                           kv_caches=None)
+    return output


 def linear_attention_fake(
     hidden_states: torch.Tensor,
-    output: torch.Tensor,
     positions: torch.Tensor,
     layer_name: str,
-) -> None:
-    return
+) -> torch.tensor:
+    return torch.empty_like(hidden_states)


 direct_register_custom_op(
     op_name="linear_attention",
     op_func=linear_attention,
-    mutates_args=["output"],
+    mutates_args=[],
     fake_impl=linear_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
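
The net effect of the diff is that linear_attention now allocates and returns its result instead of writing into a caller-provided output buffer, so the op can be registered with mutates_args=[] and its fake implementation only needs to describe the output's shape and dtype (torch.empty_like(hidden_states)). The sketch below illustrates the same return-style registration using PyTorch's generic torch.library.custom_op API (torch >= 2.4) rather than vLLM's direct_register_custom_op helper; the op name demo::linear_attention_like and its body are hypothetical placeholders, not code from this repository.

import torch

# Return-style custom op: no out-parameter, so mutates_args is empty.
@torch.library.custom_op("demo::linear_attention_like", mutates_args=())
def linear_attention_like(hidden_states: torch.Tensor,
                          positions: torch.Tensor) -> torch.Tensor:
    # Placeholder computation standing in for the real linear-attention kernel.
    return hidden_states * positions.to(hidden_states.dtype).unsqueeze(-1)

# Fake (meta) implementation used during tracing/compilation: it only has to
# report the output shape and dtype, mirroring linear_attention_fake above.
@linear_attention_like.register_fake
def _(hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(hidden_states)

# Usage: the caller no longer pre-allocates an output tensor with
# torch.empty_like() and passes it in; it just consumes the return value.
out = linear_attention_like(torch.randn(4, 8), torch.arange(4))

Because nothing is mutated in place, the compiler does not have to track aliasing between the op's inputs and a caller-owned buffer, which is what lets mutates_args drop from ["output"] to [] in the direct_register_custom_op call above.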
