Commit 1395461

[Hybrid][torch.compile] Refactor mamba2 forward to avoid obscuring linear projections under custom op (#28587)
Signed-off-by: Tomer Asida <[email protected]>
1 parent 9912b8c commit 1395461

7 files changed: +90 additions, -88 deletions
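The motivation reads directly off the diff below: previously the whole mixer forward (input projection, causal conv + SSM, gated RMSNorm, output projection) ran inside the `torch.ops.vllm.mamba_mixer2` custom op, which hides the linear projections from `torch.compile`; after the refactor only the conv + SSM core stays behind the op. The sketch below is a simplified before/after of that boundary, not the verbatim vLLM code: `self` stands for the mixer module, and tensor-parallel slicing, attention metadata handling, and the `mup_vector` scaling are omitted.

```python
import torch

# Simplified sketch only -- `self` stands for the Mamba2 mixer module; TP slicing,
# attention metadata handling, and mup_vector scaling are left out.

def forward_before(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Old boundary: in_proj -> conv1d -> SSM -> gated RMSNorm -> out_proj all run
    # inside the custom op, so torch.compile cannot see or fuse the projections.
    output = torch.empty_like(hidden_states)
    torch.ops.vllm.mamba_mixer2(hidden_states, output, self.prefix, None)
    return output

def forward_after(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # New boundary: only the conv + SSM core is behind the op; the projections and
    # the gated norm are ordinary PyTorch calls that torch.compile can trace.
    projected_states, _ = self.in_proj(hidden_states)
    ssm_output = torch.empty(
        [hidden_states.shape[0], self.num_heads * self.head_dim],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    torch.ops.vllm.mamba_mixer2(projected_states, ssm_output, self.prefix)
    gate = projected_states[..., : self.intermediate_size]
    hidden_states = self.norm(ssm_output, gate)
    output, _ = self.out_proj(hidden_states)
    return output
```

Because the op now mutates only the small `ssm_output` buffer and takes `projected_states` as input, the gate / hidden_states_B_C / dt splits happen on either side of the op boundary, which is exactly what the hunks in the first file below implement.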

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 84 additions & 72 deletions
@@ -426,6 +426,10 @@ def __init__(
         # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
         # and `set_weight_attrs` doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+        conv_weights = self.conv1d.weight.view(
+            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+        )
+        self.register_buffer("conv_weights", conv_weights, persistent=False)

         # - these are TPed by heads to reduce the size of the
         # temporal shape
@@ -459,6 +463,17 @@ def __init__(
             intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
         )

+        # - get hidden_states, B and C after depthwise convolution.
+        self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.intermediate_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -470,33 +485,80 @@ def __init__(
         self.cache_config = cache_config
         self.prefix = prefix

+        # Pre-compute sizes for forward pass
+        self.tped_intermediate_size = self.intermediate_size // self.tp_size
+        self.tped_conv_size = self.conv_dim // self.tp_size
+        self.tped_dt_size = self.num_heads // self.tp_size
+
+        self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.tped_intermediate_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
-        output: torch.Tensor,
         mup_vector: torch.Tensor | None = None,
     ):
         pass

     def forward(
         self,
         hidden_states: torch.Tensor,
-        output: torch.Tensor,
         mup_vector: torch.Tensor | None = None,
     ):
+        # 1. Gated MLP's linear projection
+        projected_states, _ = self.in_proj(hidden_states)
+        if mup_vector is not None:
+            projected_states = projected_states * mup_vector
+
+        # 2. Prepare inputs for conv + SSM
+        ssm_output = torch.empty(
+            [
+                hidden_states.shape[0],
+                (self.num_heads // self.tp_size) * self.head_dim,
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # 3. conv + SSM
+        # (split `projected_states` into hidden_states_B_C, dt in the custom op to
+        # ensure it is not treated as an intermediate tensor by torch compile)
         torch.ops.vllm.mamba_mixer2(
-            hidden_states,
-            output,
+            projected_states,
+            ssm_output,
             self.prefix,
-            mup_vector,
         )

-    def forward_cuda(
+        # 4. gated MLP
+        # GatedRMSNorm internally applying SiLU to the gate
+        # SiLU is applied internally before normalization, unlike standard
+        # norm usage
+        gate = projected_states[..., : self.tped_intermediate_size]
+        hidden_states = self.norm(ssm_output, gate)
+
+        # 5. Final linear projection
+        output, _ = self.out_proj(hidden_states)
+
+        return output
+
+    def conv_ssm_forward(
         self,
-        hidden_states: torch.Tensor,
+        projected_states: torch.Tensor,
         output: torch.Tensor,
-        mup_vector: torch.Tensor | None = None,
     ):
+        hidden_states_B_C, dt = torch.split(
+            projected_states[..., self.tped_intermediate_size :],
+            [self.tped_conv_size, self.tped_dt_size],
+            dim=-1,
+        )
+
         forward_context = get_forward_context()
         # attn_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
@@ -524,46 +586,13 @@ def forward_cuda(
             cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
             last_chunk_indices_p = attn_metadata.last_chunk_indices_p

-        # 1. Gated MLP's linear projection
-        projected_states, _ = self.in_proj(hidden_states)
-
-        if mup_vector is not None:
-            projected_states = projected_states * mup_vector
-
-        gate, hidden_states_B_C, dt = torch.split(
-            projected_states,
-            [
-                self.intermediate_size // self.tp_size,
-                self.conv_dim // self.tp_size,
-                self.num_heads // self.tp_size,
-            ],
-            dim=-1,
-        )
-
-        conv_weights = self.conv1d.weight.view(
-            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
-        )
-
-        # - get hidden_states, B and C after depthwise convolution.
-        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
-            hidden_states_B_C,
-            [
-                self.intermediate_size // self.tp_size,
-                self.groups_ssm_state_size // self.tp_size,
-                self.groups_ssm_state_size // self.tp_size,
-            ],
-            dim=-1,
-        )
-
         if attn_metadata is None:
             # profile run
             hidden_states_B_C = (
                 hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
             ).contiguous()
-            hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
-            hidden_states = self.norm(hidden_states, gate)
-            out, _ = self.out_proj(hidden_states)
-            return out
+            hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
+            return hidden_states

         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         num_prefills = attn_metadata.num_prefills  # request count
@@ -622,18 +651,8 @@ def forward_cuda(
             block_idx_first_scheduled_token_p = None
             num_computed_tokens_p = None

-        # Preallocate output tensor to avoid memcpy cost for merging prefill
-        # and decode outputs
-        preallocated_ssm_out = torch.empty(
-            [
-                num_prefill_tokens + num_decodes,
-                (self.num_heads // self.tp_size) * self.head_dim,
-            ],
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
         preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
-            preallocated_ssm_out,
+            output[:num_actual_tokens],
             [num_decodes, num_prefill_tokens],
             dim=0,
         )
@@ -658,7 +677,7 @@ def forward_cuda(
             )  # this is the form that causal-conv see
             hidden_states_B_C_p = causal_conv1d_fn(
                 x,
-                conv_weights,
+                self.conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
                 conv_states=conv_state,
@@ -673,7 +692,9 @@ def forward_cuda(
                 query_start_loc=query_start_loc_p,
             ).transpose(0, 1)[:num_prefill_tokens]

-            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
+            hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(
+                hidden_states_B_C_p
+            )

             # 3. State Space Model sequence transformation
             initial_states = None
@@ -815,15 +836,17 @@ def forward_cuda(
             hidden_states_B_C_d = causal_conv1d_update(
                 hidden_states_B_C_d,
                 conv_state,
-                conv_weights,
+                self.conv_weights,
                 self.conv1d.bias,
                 self.activation,
                 conv_state_indices=state_indices_tensor_d,
                 block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
                 initial_state_idx=block_idx_last_computed_token_d,
             )

-            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
+            hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
+                hidden_states_B_C_d
+            )

             # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
@@ -861,15 +884,6 @@ def forward_cuda(
                 out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
             )

-        # 4. gated MLP
-        # GatedRMSNorm internally applying SiLU to the gate
-        # SiLU is applied internally before normalization, unlike standard
-        # norm usage
-        hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
-
-        # 5. Final linear projection
-        output[:num_actual_tokens], _ = self.out_proj(hidden_states)
-
     def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
         assert self.model_config is not None
         assert self.cache_config is not None
@@ -901,21 +915,19 @@ def get_attn_backend(self) -> type["AttentionBackend"]:


 def mamba_mixer2(
-    hidden_states: torch.Tensor,
+    projected_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
-    mup_vector: torch.Tensor | None = None,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)
+    self.conv_ssm_forward(projected_states=projected_states, output=output)


 def mamba_mixer2_fake(
-    hidden_states: torch.Tensor,
+    projected_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
-    mup_vector: torch.Tensor | None = None,
 ) -> None:
     return
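Two supporting changes in this file are worth calling out: the flattened `conv1d` weight view is now computed once in `__init__` and registered as a non-persistent buffer (`self.conv_weights`), and the tensor-parallel slice sizes plus the B/C split (`tped_*`, `split_hidden_states_B_C_fn`) are precomputed rather than rebuilt on every forward. The snippet below is a standalone illustration of the `register_buffer(..., persistent=False)` behavior relied on here; the module and shapes are made up. A non-persistent buffer is excluded from the `state_dict` (so checkpoints are unaffected) but still follows the module through `.to()` / `.cuda()` moves.

```python
import torch
from torch import nn

class ConvView(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Stand-in for a depthwise conv1d weight of shape (channels, 1, kernel_size).
        self.conv1d = nn.Conv1d(8, 8, kernel_size=4, groups=8, bias=False)
        # Precompute the 2-D view the conv kernels expect and keep it as a buffer:
        # persistent=False -> not saved in state_dict, but still moved by .to()/.cuda().
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )
        self.register_buffer("conv_weights", conv_weights, persistent=False)

m = ConvView()
print("conv_weights" in m.state_dict())   # False: non-persistent buffers are skipped
print(m.conv_weights.shape)               # torch.Size([8, 4])
m = m.to(torch.float64)
print(m.conv_weights.dtype)               # torch.float64: buffer follows module moves
```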

vllm/model_executor/models/bamba.py

Lines changed: 1 addition & 2 deletions
@@ -138,8 +138,7 @@ def forward(
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        output = torch.empty_like(hidden_states)
-        self.mamba(hidden_states, output)
+        output = self.mamba(hidden_states)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)

vllm/model_executor/models/falcon_h1.py

Lines changed: 1 addition & 3 deletions
@@ -198,10 +198,8 @@ def forward(
         residual: torch.Tensor | None,
         **kwargs,
     ):
-        output = torch.empty_like(hidden_states)
-        self.mamba(
+        output = self.mamba(
             hidden_states,
-            output,
             mup_vector=self.mup_vector,
         )
         return output, residual

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 1 addition & 2 deletions
@@ -115,8 +115,7 @@ def forward(
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        output = torch.empty_like(hidden_states)
-        self.mamba(hidden_states, output)
+        output = self.mamba(hidden_states)
         hidden_states = residual + output * self.residual_multiplier

         residual = hidden_states

vllm/model_executor/models/mamba2.py

Lines changed: 1 addition & 2 deletions
@@ -87,8 +87,7 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        output = torch.empty_like(hidden_states)
-        self.mixer(hidden_states, output)
+        output = self.mixer(hidden_states)
         return output, residual

vllm/model_executor/models/nemotron_h.py

Lines changed: 1 addition & 2 deletions
@@ -376,8 +376,7 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        output = torch.empty_like(hidden_states)
-        self.mixer(hidden_states, output)
+        output = self.mixer(hidden_states)
         return output, residual

vllm/model_executor/models/zamba2.py

Lines changed: 1 addition & 5 deletions
@@ -567,11 +567,7 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)

         # Process through Mamba mixer
-        output = torch.empty_like(hidden_states)
-        self.mamba(
-            hidden_states,
-            output,
-        )
+        output = self.mamba(hidden_states)

         # residual connection after mamba
         hidden_states = residual + output
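The remaining six files are the call sites, and the change is identical in each: the caller no longer preallocates an `output` buffer and passes it into the mixer, it simply takes the mixer's return value (falcon_h1 still forwards `mup_vector=self.mup_vector`, which is now applied in the eager `forward` rather than inside the custom op). A minimal sketch of the call-site pattern, with `self.mamba` standing for any of the mixers above:

```python
# Old pattern: caller owns the buffer, the mixer writes into it in place.
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output)

# New pattern: the mixer allocates its small SSM buffer internally and returns
# the result of its final out_proj, so the call reads like a normal nn.Module.
output = self.mamba(hidden_states)
```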

0 commit comments