 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
@@ -261,12 +262,14 @@ def __init__(self,
261262 ), "Tensor parallel world size must divide num heads."
262263
263264 assert (n_groups % self .tp_size ) == 0 or n_groups == 1 , (
264- "If tensor parallel world size does not divide num_heads , "
265+ "If tensor parallel world size does not divide num_groups , "
265266 "then num_groups must equal 1." )
266267
267- assert (
268- self .tp_size == 1 or quant_config is None
269- ), "Tensor parallel currently not supported for quantized models."
268+ assert (n_groups % self .tp_size == 0 ) or self .tp_size == 1 or \
269+ quant_config is None , (
270+ "Tensor parallel currently supported for quantized models only "
271+ "if tensor parallel world size divides num groups."
272+ )
270273
271274 self .ssm_state_size = ssm_state_size
272275 self .conv_kernel_size = conv_kernel_size
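
The relaxed assertion above now accepts quantized checkpoints under tensor parallelism whenever the group dimension shards evenly across ranks, since that case no longer needs the custom interleaving weight loader. A minimal sketch of the accepted combinations (the standalone helper and its name are hypothetical, for illustration only):

```python
# Illustrative sketch only, not part of the diff: which combinations pass
# the relaxed assert above (function form is hypothetical).
def tp_quant_supported(n_groups: int, tp_size: int, quantized: bool) -> bool:
    if not quantized:
        return True  # unquantized models were already supported
    # quantized models: allowed when no sharding is needed (tp_size == 1)
    # or when the groups divide evenly across the TP ranks
    return tp_size == 1 or n_groups % tp_size == 0

assert tp_quant_supported(n_groups=8, tp_size=4, quantized=True)      # accepted
assert not tp_quant_supported(n_groups=1, tp_size=4, quantized=True)  # rejected
```
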
@@ -285,101 +288,135 @@ def __init__(self,
             n_groups, self.tp_size)
         self.n_groups = n_groups + groups
 
-        self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
-        self.conv1d = ColumnParallelLinear(
-            input_size=conv_kernel_size,
-            output_size=self.conv_dim,
-            bias=use_conv_bias,
-            quant_config=None,
-            prefix=f"{prefix}.conv1d",
-        )
-        # unsqueeze to fit conv1d weights shape into the linear weights shape.
-        # Can't do this in `weight_loader` since it already exists in
-        # `ColumnParallelLinear` and `set_weight_attrs`
-        # doesn't allow to override it
-        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+        self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
+        self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
+
+        if n_groups % self.tp_size == 0:
+            self.conv1d = MergedColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_sizes=[
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                ],
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
 
-        self.in_proj = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=intermediate_size + self.conv_dim + self.num_heads,
-            bias=use_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
-        )
+            self.in_proj = MergedColumnParallelLinear(
+                input_size=hidden_size,
+                output_sizes=[
+                    intermediate_size,
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                    self.num_heads,
+                ],
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
+        else:
+            # This is the n_groups == 1 case,
+            # where we need to duplicate groups if TP>1.
+
+            self.conv1d = ColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_size=self.conv_dim,
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
 
-        # - because in_proj is a concatenation of 3 weights, we
-        #   need to interleave them before sharding
-        # - use the custom weight loader mamba_v2_sharded_weight_loader
-        #   for conv1d.bias, covn1d.weight and in_proj.weight
-        # - need to set these settings, to assign the groups to the head shards
-        group_shard_settings = (
-            self.n_groups * self.ssm_state_size,  # expected model size
-            (self.n_groups - n_groups) *
-            self.ssm_state_size,  # extra dims assigned
-            n_groups == 1,  # if there was only one group
-        )
-        intermediate_settings = (intermediate_size, 0, False)
-        head_settings = (self.num_heads, 0, False)
-
-        # - the weight already has a "weight_loader" attribute
-        #   which set_weight_attrs will raise if we do not
-        #   delete before trying to override it
-        # - ditto for the other two weights below
-        delattr(self.conv1d.bias, "weight_loader")
-        set_weight_attrs(
-            self.conv1d.bias,
-            {
-                "weight_loader":
-                mamba_v2_sharded_weight_loader(
-                    [
-                        intermediate_settings,
-                        group_shard_settings,
-                        group_shard_settings,
-                    ],
-                    self.tp_size,
-                    tp_rank,
-                )
-            },
-        )
+            self.in_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size + self.conv_dim + self.num_heads,
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
 
-        delattr(self.conv1d.weight, "weight_loader")
-        set_weight_attrs(
-            self.conv1d.weight,
-            {
-                "weight_loader":
-                mamba_v2_sharded_weight_loader(
-                    [
-                        intermediate_settings,
-                        group_shard_settings,
-                        group_shard_settings,
-                    ],
-                    self.tp_size,
-                    tp_rank,
-                )
-            },
-        )
+            # - because in_proj is a concatenation of 3 weights, we
+            #   need to interleave them before sharding
+            # - use the custom weight loader mamba_v2_sharded_weight_loader
+            #   for conv1d.bias, covn1d.weight and in_proj.weight
+            # - need to set these settings, to assign the groups
+            #   to the head shards
+            group_shard_settings = (
+                self.groups_ssm_state_size,  # expected model size
+                (self.n_groups - n_groups) *
+                self.ssm_state_size,  # extra dims assigned
+                n_groups == 1,  # if there was only one group
+            )
+            intermediate_settings = (intermediate_size, 0, False)
+            head_settings = (self.num_heads, 0, False)
+
+            # - the weight already has a "weight_loader" attribute
+            #   which set_weight_attrs will raise if we do not
+            #   delete before trying to override it
+            # - ditto for the other two weights below
+            delattr(self.conv1d.bias, "weight_loader")
+            set_weight_attrs(
+                self.conv1d.bias,
+                {
+                    "weight_loader":
+                    mamba_v2_sharded_weight_loader(
+                        [
+                            intermediate_settings,
+                            group_shard_settings,
+                            group_shard_settings,
+                        ],
+                        self.tp_size,
+                        tp_rank,
+                    )
+                },
+            )
 
-        if quant_config is None:
-            # - quant layers do not have a weight loader
-            delattr(self.in_proj.weight, "weight_loader")
+            delattr(self.conv1d.weight, "weight_loader")
             set_weight_attrs(
-                self.in_proj.weight,
+                self.conv1d.weight,
                 {
                     "weight_loader":
                     mamba_v2_sharded_weight_loader(
                         [
-                            intermediate_settings,  # for gate
                             intermediate_settings,
                             group_shard_settings,
                             group_shard_settings,
-                            head_settings,  # for dt
                         ],
                         self.tp_size,
                         tp_rank,
                     )
                 },
             )
 
+            if quant_config is None:
+                # - quant layers do not have a weight loader
+                delattr(self.in_proj.weight, "weight_loader")
+                set_weight_attrs(
+                    self.in_proj.weight,
+                    {
+                        "weight_loader":
+                        mamba_v2_sharded_weight_loader(
+                            [
+                                intermediate_settings,  # for gate
+                                intermediate_settings,
+                                group_shard_settings,
+                                group_shard_settings,
+                                head_settings,  # for dt
+                            ],
+                            self.tp_size,
+                            tp_rank,
+                        )
+                    },
+                )
+
+        # unsqueeze to fit conv1d weights shape into the linear weights shape.
+        # Can't do this in `weight_loader` since it already exists in
+        # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
+        # and `set_weight_attrs` doesn't allow to override it
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
         # - these are TPed by heads to reduce the size of the
         #   temporal shape
         self.A = nn.Parameter(
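
The hunk above introduces two construction paths. When n_groups divides the tensor-parallel world size, conv1d and in_proj become MergedColumnParallelLinear layers whose output_sizes spell out the logical sub-projections (gate, hidden states, B, C, dt), so the stock merged-column weight loader shards each sub-matrix on its own and quantized weights load through the normal path. Otherwise (in practice the n_groups == 1 case) the previous ColumnParallelLinear plus mamba_v2_sharded_weight_loader route is kept to replicate the single group across ranks. The sketch below, with hypothetical sizes not taken from the diff, shows how the per-rank widths fall out of the output_sizes decomposition:

```python
# Illustrative sketch with hypothetical sizes (not from the diff):
# per-rank shard widths of the merged in_proj when n_groups % tp_size == 0.
intermediate_size = 4096
n_groups, ssm_state_size, num_heads = 8, 128, 64
tp_size = 4

groups_ssm_state_size = n_groups * ssm_state_size
output_sizes = [
    intermediate_size,      # gate
    intermediate_size,      # hidden states fed into conv1d
    groups_ssm_state_size,  # B
    groups_ssm_state_size,  # C
    num_heads,              # dt
]

# MergedColumnParallelLinear shards every entry independently, so each
# rank ends up holding 1/tp_size of each logical sub-projection:
per_rank = [size // tp_size for size in output_sizes]
print(per_rank)  # [1024, 1024, 256, 256, 16]
```
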
@@ -498,8 +535,6 @@ def forward_cuda(
             chunk_indices_p = mamba2_metadata.chunk_indices
             chunk_offsets_p = mamba2_metadata.chunk_offsets
 
-        groups_time_state_size = self.n_groups * self.ssm_state_size
-
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
 
@@ -524,8 +559,8 @@ def forward_cuda(
             hidden_states_B_C,
             [
                 self.intermediate_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
             ],
             dim=-1,
         )
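
In forward_cuda the split of hidden_states_B_C now reuses the groups_ssm_state_size cached in __init__ instead of recomputing it locally. A sketch of the resulting per-rank split, using the same hypothetical sizes as above:

```python
# Illustrative sketch with hypothetical sizes: per-rank split of
# hidden_states_B_C into the conv/hidden part and the B and C group states.
import torch

tp_size, intermediate_size = 4, 4096
groups_ssm_state_size = 8 * 128  # n_groups * ssm_state_size
conv_dim_per_rank = (intermediate_size + 2 * groups_ssm_state_size) // tp_size

hidden_states_B_C = torch.randn(2, conv_dim_per_rank)  # (tokens, sharded dim)
hidden_states, B, C = torch.split(
    hidden_states_B_C,
    [
        intermediate_size // tp_size,
        groups_ssm_state_size // tp_size,
        groups_ssm_state_size // tp_size,
    ],
    dim=-1,
)
print(hidden_states.shape, B.shape, C.shape)
# torch.Size([2, 1024]) torch.Size([2, 256]) torch.Size([2, 256])
```
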