Commit 8aa924f

tomeras91 and gemini-code-assist[bot] authored and committed
[Mamba] Support TP>1 with quantization for mamba2 mixer in case n_groups % tp_size == 0 (vllm-project#24593)
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: tomeras91 <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent a62f013 commit 8aa924f
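
In short, the change builds conv1d and in_proj as MergedColumnParallelLinear layers whenever n_groups % tp_size == 0, so each tensor-parallel rank owns an even slice of every output block (gate, hidden states, B, C, dt) and the standard quantized weight loaders work unchanged; only the n_groups == 1 duplication path keeps the custom sharded loader. A minimal standalone sketch of that divisibility reasoning follows (the helper name and all sizes are illustrative, not taken from vLLM or any real model config):

# Illustrative sketch, not vLLM code: why the MergedColumnParallelLinear path
# requires n_groups % tp_size == 0. Every merged output block of in_proj must
# split evenly across tensor-parallel ranks.

def per_rank_output_sizes(intermediate_size: int, n_groups: int,
                          ssm_state_size: int, num_heads: int,
                          tp_size: int) -> list[int]:
    """Per-rank sizes of the merged in_proj blocks [gate, hidden, B, C, dt]."""
    groups_ssm_state_size = n_groups * ssm_state_size
    output_sizes = [
        intermediate_size,      # gate
        intermediate_size,      # hidden states fed to conv1d
        groups_ssm_state_size,  # B
        groups_ssm_state_size,  # C
        num_heads,              # dt
    ]
    # Each block must divide evenly across TP ranks, hence the new assertion
    # on n_groups % tp_size in the commit.
    assert all(s % tp_size == 0 for s in output_sizes), \
        "each merged block must divide evenly across TP ranks"
    return [s // tp_size for s in output_sizes]

if __name__ == "__main__":
    # Made-up example: n_groups=8, tp_size=4 -> each rank holds 2 groups of B and C.
    print(per_rank_output_sizes(intermediate_size=4096, n_groups=8,
                                ssm_state_size=128, num_heads=64, tp_size=4))
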

File tree

1 file changed: +119 additions, -84 deletions


vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 119 additions & 84 deletions
@@ -19,6 +19,7 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
@@ -261,12 +262,14 @@ def __init__(self,
         ), "Tensor parallel world size must divide num heads."
 
         assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
-            "If tensor parallel world size does not divide num_heads, "
+            "If tensor parallel world size does not divide num_groups, "
             "then num_groups must equal 1.")
 
-        assert (
-            self.tp_size == 1 or quant_config is None
-        ), "Tensor parallel currently not supported for quantized models."
+        assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \
+            quant_config is None, (
+                "Tensor parallel currently supported for quantized models only "
+                "if tensor parallel world size divides num groups."
+            )
 
         self.ssm_state_size = ssm_state_size
         self.conv_kernel_size = conv_kernel_size
@@ -285,101 +288,135 @@ def __init__(self,
             n_groups, self.tp_size)
         self.n_groups = n_groups + groups
 
-        self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
-        self.conv1d = ColumnParallelLinear(
-            input_size=conv_kernel_size,
-            output_size=self.conv_dim,
-            bias=use_conv_bias,
-            quant_config=None,
-            prefix=f"{prefix}.conv1d",
-        )
-        # unsqueeze to fit conv1d weights shape into the linear weights shape.
-        # Can't do this in `weight_loader` since it already exists in
-        # `ColumnParallelLinear` and `set_weight_attrs`
-        # doesn't allow to override it
-        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+        self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
+        self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
+
+        if n_groups % self.tp_size == 0:
+            self.conv1d = MergedColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_sizes=[
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                ],
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
 
-        self.in_proj = ColumnParallelLinear(
-            input_size=hidden_size,
-            output_size=intermediate_size + self.conv_dim + self.num_heads,
-            bias=use_bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.in_proj",
-        )
+            self.in_proj = MergedColumnParallelLinear(
+                input_size=hidden_size,
+                output_sizes=[
+                    intermediate_size,
+                    intermediate_size,
+                    self.groups_ssm_state_size,
+                    self.groups_ssm_state_size,
+                    self.num_heads,
+                ],
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
+        else:
+            # This is the n_groups == 1 case,
+            # where we need to duplicate groups if TP>1.
+
+            self.conv1d = ColumnParallelLinear(
+                input_size=conv_kernel_size,
+                output_size=self.conv_dim,
+                bias=use_conv_bias,
+                quant_config=None,
+                prefix=f"{prefix}.conv1d",
+            )
 
-        # - because in_proj is a concatenation of 3 weights, we
-        #   need to interleave them before sharding
-        # - use the custom weight loader mamba_v2_sharded_weight_loader
-        #   for conv1d.bias, covn1d.weight and in_proj.weight
-        # - need to set these settings, to assign the groups to the head shards
-        group_shard_settings = (
-            self.n_groups * self.ssm_state_size,  # expected model size
-            (self.n_groups - n_groups) *
-            self.ssm_state_size,  # extra dims assigned
-            n_groups == 1,  # if there was only one group
-        )
-        intermediate_settings = (intermediate_size, 0, False)
-        head_settings = (self.num_heads, 0, False)
-
-        # - the weight already has a "weight_loader" attribute
-        #   which set_weight_attrs will raise if we do not
-        #   delete before trying to override it
-        # - ditto for the other two weights below
-        delattr(self.conv1d.bias, "weight_loader")
-        set_weight_attrs(
-            self.conv1d.bias,
-            {
-                "weight_loader":
-                mamba_v2_sharded_weight_loader(
-                    [
-                        intermediate_settings,
-                        group_shard_settings,
-                        group_shard_settings,
-                    ],
-                    self.tp_size,
-                    tp_rank,
-                )
-            },
-        )
+            self.in_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size + self.conv_dim + self.num_heads,
+                bias=use_bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.in_proj",
+            )
 
-        delattr(self.conv1d.weight, "weight_loader")
-        set_weight_attrs(
-            self.conv1d.weight,
-            {
-                "weight_loader":
-                mamba_v2_sharded_weight_loader(
-                    [
-                        intermediate_settings,
-                        group_shard_settings,
-                        group_shard_settings,
-                    ],
-                    self.tp_size,
-                    tp_rank,
-                )
-            },
-        )
+            # - because in_proj is a concatenation of 3 weights, we
+            #   need to interleave them before sharding
+            # - use the custom weight loader mamba_v2_sharded_weight_loader
+            #   for conv1d.bias, covn1d.weight and in_proj.weight
+            # - need to set these settings, to assign the groups
+            #   to the head shards
+            group_shard_settings = (
+                self.groups_ssm_state_size,  # expected model size
+                (self.n_groups - n_groups) *
+                self.ssm_state_size,  # extra dims assigned
+                n_groups == 1,  # if there was only one group
+            )
+            intermediate_settings = (intermediate_size, 0, False)
+            head_settings = (self.num_heads, 0, False)
+
+            # - the weight already has a "weight_loader" attribute
+            #   which set_weight_attrs will raise if we do not
+            #   delete before trying to override it
+            # - ditto for the other two weights below
+            delattr(self.conv1d.bias, "weight_loader")
+            set_weight_attrs(
+                self.conv1d.bias,
+                {
+                    "weight_loader":
+                    mamba_v2_sharded_weight_loader(
+                        [
+                            intermediate_settings,
+                            group_shard_settings,
+                            group_shard_settings,
+                        ],
+                        self.tp_size,
+                        tp_rank,
+                    )
+                },
+            )
 
-        if quant_config is None:
-            # - quant layers do not have a weight loader
-            delattr(self.in_proj.weight, "weight_loader")
+            delattr(self.conv1d.weight, "weight_loader")
             set_weight_attrs(
-                self.in_proj.weight,
+                self.conv1d.weight,
                 {
                     "weight_loader":
                     mamba_v2_sharded_weight_loader(
                         [
-                            intermediate_settings,  # for gate
                             intermediate_settings,
                             group_shard_settings,
                             group_shard_settings,
-                            head_settings,  # for dt
                         ],
                         self.tp_size,
                         tp_rank,
                     )
                 },
             )
 
+            if quant_config is None:
+                # - quant layers do not have a weight loader
+                delattr(self.in_proj.weight, "weight_loader")
+                set_weight_attrs(
+                    self.in_proj.weight,
+                    {
+                        "weight_loader":
+                        mamba_v2_sharded_weight_loader(
+                            [
+                                intermediate_settings,  # for gate
+                                intermediate_settings,
+                                group_shard_settings,
+                                group_shard_settings,
+                                head_settings,  # for dt
+                            ],
+                            self.tp_size,
+                            tp_rank,
+                        )
+                    },
+                )
+
+        # unsqueeze to fit conv1d weights shape into the linear weights shape.
+        # Can't do this in `weight_loader` since it already exists in
+        # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
+        # and `set_weight_attrs` doesn't allow to override it
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
         # - these are TPed by heads to reduce the size of the
         #   temporal shape
         self.A = nn.Parameter(
@@ -498,8 +535,6 @@ def forward_cuda(
             chunk_indices_p = mamba2_metadata.chunk_indices
             chunk_offsets_p = mamba2_metadata.chunk_offsets
 
-        groups_time_state_size = self.n_groups * self.ssm_state_size
-
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
 
@@ -524,8 +559,8 @@
             hidden_states_B_C,
             [
                 self.intermediate_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
-                groups_time_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
             ],
             dim=-1,
         )
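
For context on the last hunk: the renamed self.groups_ssm_state_size is what forward_cuda now uses to cut the fused conv output back into hidden states, B and C on each rank. A toy sketch of that split with made-up sizes (plain torch.split, not vLLM's actual tensors or shapes):

# Minimal sketch under assumed shapes: splitting a fused [tokens, dim] tensor
# into per-rank hidden_states, B and C slices, as the forward_cuda hunk does.
import torch

tp_size = 2
intermediate_size, n_groups, ssm_state_size = 512, 4, 64
groups_ssm_state_size = n_groups * ssm_state_size

tokens = 3
per_rank_dim = (intermediate_size + 2 * groups_ssm_state_size) // tp_size
hidden_states_B_C = torch.randn(tokens, per_rank_dim)

hidden_states, B, C = torch.split(
    hidden_states_B_C,
    [
        intermediate_size // tp_size,         # per-rank hidden states
        groups_ssm_state_size // tp_size,     # per-rank B
        groups_ssm_state_size // tp_size,     # per-rank C
    ],
    dim=-1,
)
print(hidden_states.shape, B.shape, C.shape)  # [3, 256], [3, 128], [3, 128]
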
