@@ -426,6 +426,10 @@ def __init__(
         # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
         # and `set_weight_attrs` doesn't allow to override it
         self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+        conv_weights = self.conv1d.weight.view(
+            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+        )
+        self.register_buffer("conv_weights", conv_weights, persistent=False)

         # - these are TPed by heads to reduce the size of the
         # temporal shape
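The hunk above computes the 2D view of the depthwise-conv weight once at init time and registers it as a non-persistent buffer, instead of re-viewing the weight on every forward call. A minimal standalone sketch of the pattern (hypothetical module name and sizes, not the vLLM class):

    import torch
    import torch.nn as nn

    class ConvWeightView(nn.Module):
        def __init__(self, channels: int = 8, kernel_size: int = 4):
            super().__init__()
            # Depthwise conv: weight shape is [channels, 1, kernel_size].
            self.conv1d = nn.Conv1d(channels, channels, kernel_size, groups=channels)
            # 2D view [channels, kernel_size], the layout the conv kernels consume.
            conv_weights = self.conv1d.weight.view(
                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
            )
            # persistent=False keeps the derived view out of the checkpoint.
            self.register_buffer("conv_weights", conv_weights, persistent=False)

    m = ConvWeightView()
    assert "conv_weights" not in m.state_dict()                      # not serialized
    assert m.conv_weights.data_ptr() == m.conv1d.weight.data_ptr()   # shares storage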
@@ -459,6 +463,17 @@ def __init__(
             intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
         )

+        # - get hidden_states, B and C after depthwise convolution.
+        self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.intermediate_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
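The precomputed split above simply partitions the conv output's last dimension into the SSM's hidden states, B, and C blocks. A self-contained illustration with toy sizes (not the model's real dimensions):

    import torch

    intermediate_size, groups_ssm_state_size = 16, 4
    # The conv output packs [hidden_states | B | C] along the last dim.
    hidden_states_B_C = torch.randn(2, intermediate_size + 2 * groups_ssm_state_size)
    hidden_states, B, C = torch.split(
        hidden_states_B_C,
        [intermediate_size, groups_ssm_state_size, groups_ssm_state_size],
        dim=-1,
    )
    assert hidden_states.shape[-1] == intermediate_size
    assert B.shape[-1] == C.shape[-1] == groups_ssm_state_size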
@@ -470,33 +485,80 @@ def __init__(
         self.cache_config = cache_config
         self.prefix = prefix

+        # Pre-compute sizes for forward pass
+        self.tped_intermediate_size = self.intermediate_size // self.tp_size
+        self.tped_conv_size = self.conv_dim // self.tp_size
+        self.tped_dt_size = self.num_heads // self.tp_size
+
+        self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.tped_intermediate_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
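The tensor-parallel-sharded sizes added above describe how in_proj's output is laid out: [gate | hidden_states_B_C | dt] along the last dimension, which forward() and conv_ssm_forward() slice apart further down. A toy, TP-free sketch of that layout (made-up sizes):

    import torch

    intermediate_size, conv_dim, num_heads = 16, 24, 2
    projected_states = torch.randn(3, intermediate_size + conv_dim + num_heads)
    gate = projected_states[..., :intermediate_size]
    hidden_states_B_C, dt = torch.split(
        projected_states[..., intermediate_size:], [conv_dim, num_heads], dim=-1
    )
    assert gate.shape[-1] == intermediate_size
    assert hidden_states_B_C.shape[-1] == conv_dim and dt.shape[-1] == num_heads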
     def forward_native(
         self,
         hidden_states: torch.Tensor,
-        output: torch.Tensor,
         mup_vector: torch.Tensor | None = None,
     ):
         pass

     def forward(
         self,
         hidden_states: torch.Tensor,
-        output: torch.Tensor,
         mup_vector: torch.Tensor | None = None,
     ):
+        # 1. Gated MLP's linear projection
+        projected_states, _ = self.in_proj(hidden_states)
+        if mup_vector is not None:
+            projected_states = projected_states * mup_vector
+
+        # 2. Prepare inputs for conv + SSM
+        ssm_output = torch.empty(
+            [
+                hidden_states.shape[0],
+                (self.num_heads // self.tp_size) * self.head_dim,
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # 3. conv + SSM
+        # (split `projected_states` into hidden_states_B_C, dt in the custom op to
+        # ensure it is not treated as an intermediate tensor by torch compile)
         torch.ops.vllm.mamba_mixer2(
-            hidden_states,
-            output,
+            projected_states,
+            ssm_output,
             self.prefix,
-            mup_vector,
         )

-    def forward_cuda(
+        # 4. gated MLP
+        # GatedRMSNorm internally applies SiLU to the gate
+        # SiLU is applied internally before normalization, unlike standard
+        # norm usage
+        gate = projected_states[..., : self.tped_intermediate_size]
+        hidden_states = self.norm(ssm_output, gate)
+
+        # 5. Final linear projection
+        output, _ = self.out_proj(hidden_states)
+
+        return output
+
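The step-4 comment above is the key subtlety: the gate is pushed through SiLU inside the norm, before RMS normalization. A hedged, eager-mode sketch of the reference math only (the real layer is vLLM's Mixer2RMSNormGated, whose grouped/tensor-parallel handling is more involved):

    import torch
    import torch.nn.functional as F

    def gated_rms_norm(x: torch.Tensor, gate: torch.Tensor,
                       weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # SiLU is applied to the gate first, then the product is RMS-normalized.
        x = x * F.silu(gate)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    x, gate = torch.randn(2, 8), torch.randn(2, 8)
    out = gated_rms_norm(x, gate, torch.ones(8))
    assert out.shape == x.shape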
+    def conv_ssm_forward(
         self,
-        hidden_states: torch.Tensor,
+        projected_states: torch.Tensor,
         output: torch.Tensor,
-        mup_vector: torch.Tensor | None = None,
     ):
+        hidden_states_B_C, dt = torch.split(
+            projected_states[..., self.tped_intermediate_size :],
+            [self.tped_conv_size, self.tped_dt_size],
+            dim=-1,
+        )
+
         forward_context = get_forward_context()
         # attn_metadata contains metadata necessary for the mamba2 triton
         # kernels to operate in continuous batching and in chunked prefill
@@ -524,46 +586,13 @@ def forward_cuda(
             cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
             last_chunk_indices_p = attn_metadata.last_chunk_indices_p

-        # 1. Gated MLP's linear projection
-        projected_states, _ = self.in_proj(hidden_states)
-
-        if mup_vector is not None:
-            projected_states = projected_states * mup_vector
-
-        gate, hidden_states_B_C, dt = torch.split(
-            projected_states,
-            [
-                self.intermediate_size // self.tp_size,
-                self.conv_dim // self.tp_size,
-                self.num_heads // self.tp_size,
-            ],
-            dim=-1,
-        )
-
-        conv_weights = self.conv1d.weight.view(
-            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
-        )
-
-        # - get hidden_states, B and C after depthwise convolution.
-        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
-            hidden_states_B_C,
-            [
-                self.intermediate_size // self.tp_size,
-                self.groups_ssm_state_size // self.tp_size,
-                self.groups_ssm_state_size // self.tp_size,
-            ],
-            dim=-1,
-        )
-
         if attn_metadata is None:
             # profile run
             hidden_states_B_C = (
                 hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
             ).contiguous()
-            hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
-            hidden_states = self.norm(hidden_states, gate)
-            out, _ = self.out_proj(hidden_states)
-            return out
+            hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
+            return hidden_states

         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         num_prefills = attn_metadata.num_prefills  # request count
@@ -622,18 +651,8 @@ def forward_cuda(
             block_idx_first_scheduled_token_p = None
             num_computed_tokens_p = None

-        # Preallocate output tensor to avoid memcpy cost for merging prefill
-        # and decode outputs
-        preallocated_ssm_out = torch.empty(
-            [
-                num_prefill_tokens + num_decodes,
-                (self.num_heads // self.tp_size) * self.head_dim,
-            ],
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
         preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
-            preallocated_ssm_out,
+            output[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
             dim=0,
         )
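After this change the decode and prefill kernels write directly into slices of the caller-provided output buffer, so no separate preallocation or merge copy is needed. A small toy demonstration that torch.split returns writable views (illustrative sizes only):

    import torch

    num_decodes, num_prefill_tokens, hidden = 3, 5, 8
    output = torch.zeros(num_decodes + num_prefill_tokens, hidden)
    out_d, out_p = torch.split(output, [num_decodes, num_prefill_tokens], dim=0)
    out_d.fill_(1.0)  # stand-in for the decode kernel writing its results
    out_p.fill_(2.0)  # stand-in for the prefill kernel writing its results
    assert torch.equal(output[:num_decodes], torch.ones(num_decodes, hidden))
    assert torch.equal(
        output[num_decodes:], torch.full((num_prefill_tokens, hidden), 2.0)
    )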
@@ -658,7 +677,7 @@ def forward_cuda(
             )  # this is the form that causal-conv see
             hidden_states_B_C_p = causal_conv1d_fn(
                 x,
-                conv_weights,
+                self.conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
                 conv_states=conv_state,
@@ -673,7 +692,9 @@ def forward_cuda(
                 query_start_loc=query_start_loc_p,
             ).transpose(0, 1)[:num_prefill_tokens]

-            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
+            hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(
+                hidden_states_B_C_p
+            )

             # 3. State Space Model sequence transformation
             initial_states = None
@@ -815,15 +836,17 @@ def forward_cuda(
             hidden_states_B_C_d = causal_conv1d_update(
                 hidden_states_B_C_d,
                 conv_state,
-                conv_weights,
+                self.conv_weights,
                 self.conv1d.bias,
                 self.activation,
                 conv_state_indices=state_indices_tensor_d,
                 block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
                 initial_state_idx=block_idx_last_computed_token_d,
             )

-            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
+            hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
+                hidden_states_B_C_d
+            )

             # 3. State Space Model sequence transformation
             n_groups = self.n_groups // self.tp_size
@@ -861,15 +884,6 @@ def forward_cuda(
                 out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
             )

-        # 4. gated MLP
-        # GatedRMSNorm internally applying SiLU to the gate
-        # SiLU is applied internally before normalization, unlike standard
-        # norm usage
-        hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
-
-        # 5. Final linear projection
-        output[:num_actual_tokens], _ = self.out_proj(hidden_states)
-
     def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
         assert self.model_config is not None
         assert self.cache_config is not None
@@ -901,21 +915,19 @@ def get_attn_backend(self) -> type["AttentionBackend"]:


 def mamba_mixer2(
-    hidden_states: torch.Tensor,
+    projected_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
-    mup_vector: torch.Tensor | None = None,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)
+    self.conv_ssm_forward(projected_states=projected_states, output=output)


 def mamba_mixer2_fake(
-    hidden_states: torch.Tensor,
+    projected_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
-    mup_vector: torch.Tensor | None = None,
 ) -> None:
     return

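The custom op's signature now carries only tensors plus a layer name; the layer object itself is resolved through the forward context at call time. A toy version of that look-up-and-dispatch pattern (hypothetical registry and class names, not vLLM's actual registration machinery):

    import torch

    _LAYER_REGISTRY: dict = {}

    class ToyMixer:
        def __init__(self, name: str, out_dim: int):
            self.out_dim = out_dim
            _LAYER_REGISTRY[name] = self

        def conv_ssm_forward(self, projected_states: torch.Tensor,
                             output: torch.Tensor) -> None:
            # Stand-in for conv + SSM: write into the preallocated buffer in place.
            output.copy_(projected_states[..., : self.out_dim])

    def toy_mamba_mixer2(projected_states: torch.Tensor, output: torch.Tensor,
                         layer_name: str) -> None:
        # Look the layer up by name, exactly like the op body above does via
        # forward_context.no_compile_layers[layer_name].
        _LAYER_REGISTRY[layer_name].conv_ssm_forward(projected_states, output)

    ToyMixer("model.layers.0.mixer", out_dim=16)
    x, out = torch.randn(4, 32), torch.empty(4, 16)
    toy_mamba_mixer2(x, out, "model.layers.0.mixer")
    assert torch.equal(out, x[:, :16])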