Commit ee2eb6e

vllmellm and kliuae authored
[Model] Qwen2.5 VL SiLU-and-Mul (#22066)
Signed-off-by: kf <[email protected]>
Signed-off-by: vllmellm <[email protected]>
Co-authored-by: kf <[email protected]>
1 parent 2332243 commit ee2eb6e


vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 21 additions & 23 deletions

@@ -43,9 +43,10 @@
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig

@@ -171,16 +172,12 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(in_features,
-                                              hidden_features,
-                                              bias=bias,
-                                              quant_config=quant_config,
-                                              prefix=f"{prefix}.gate_proj")
-        self.up_proj = ColumnParallelLinear(in_features,
-                                            hidden_features,
-                                            bias=bias,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.up_proj")
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(hidden_features,
                                            in_features,
                                            bias=bias,
@@ -189,10 +186,9 @@ def __init__(self,
         self.act_fn = act_fn

     def forward(self, x: torch.Tensor):
-        x_gate, _ = self.gate_proj(x)
-        x_gate = self.act_fn(x_gate)
-        x_up, _ = self.up_proj(x)
-        x_down, _ = self.down_proj(x_gate * x_up)
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x_down, _ = self.down_proj(x)
         return x_down
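
The functional behavior of the MLP is unchanged: the fused activation splits the concatenated gate/up projection and computes SiLU(gate) * up in one op instead of two separate projections plus an elementwise multiply. A minimal sketch in plain PyTorch (illustrative shapes, not vLLM's kernels) of why the new forward matches the old one:

```python
import torch
import torch.nn.functional as F

def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # Assumed convention: first half of the last dim is gate, second half is up.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

x_gate = torch.randn(4, 3420)   # stand-in for the old gate_proj output
x_up = torch.randn(4, 3420)     # stand-in for the old up_proj output

old_path = F.silu(x_gate) * x_up                            # previous forward
new_path = silu_and_mul(torch.cat([x_gate, x_up], dim=-1))  # fused forward
assert torch.allclose(old_path, new_path)
```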

@@ -540,14 +536,14 @@ def __init__(
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}")
+            Qwen2_5_VisionBlock(dim=self.hidden_size,
+                                num_heads=self.num_heads,
+                                mlp_hidden_dim=vision_config.intermediate_size,
+                                act_fn=get_act_and_mul_fn(
+                                    vision_config.hidden_act),
+                                norm_layer=norm_layer,
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.blocks.{layer_idx}")
             for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
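
The hunk above swaps the plain activation lookup (`_ACTIVATION_REGISTRY[...]`) for `get_act_and_mul_fn`, which, judging from this call site, takes the HF activation name and returns a callable that consumes the concatenated [gate, up] tensor. A hedged usage sketch (the literal `"silu"` name and the tensor shapes are assumptions for illustration; the real code passes `vision_config.hidden_act`):

```python
import torch
from vllm.model_executor.layers.activation import get_act_and_mul_fn

# Assumption: "silu" maps to a SiLU-and-mul style module, per the commit title.
act_fn = get_act_and_mul_fn("silu")

gate_up = torch.randn(2, 2 * 3420)  # concatenated [gate, up] projection output
hidden = act_fn(gate_up)            # expected shape (2, 3420), i.e. silu(gate) * up
```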

@@ -752,6 +748,8 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("attn.qkv.", "attn.q.", "q"),
             ("attn.qkv.", "attn.k.", "k"),
             ("attn.qkv.", "attn.v.", "v"),
+            ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
+            ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
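
The two new mapping entries tell the loader that checkpoint tensors named `mlp.gate_proj.*` and `mlp.up_proj.*` should be copied into shard 0 and shard 1 of the merged `mlp.gate_up_proj` parameter. A minimal sketch of the effect at load time (plain tensors, no tensor parallelism; names and sizes are illustrative, not vLLM's loader code):

```python
import torch

in_features, hidden_features = 1280, 3420  # illustrative sizes

# The checkpoint still stores the two projections separately.
ckpt = {
    "mlp.gate_proj.weight": torch.randn(hidden_features, in_features),
    "mlp.up_proj.weight": torch.randn(hidden_features, in_features),
}

# The merged parameter stacks both projections along the output dimension.
gate_up_weight = torch.empty(2 * hidden_features, in_features)
shard_map = {"mlp.gate_proj.weight": 0, "mlp.up_proj.weight": 1}
for ckpt_name, shard_id in shard_map.items():
    start = shard_id * hidden_features
    gate_up_weight[start:start + hidden_features].copy_(ckpt[ckpt_name])
```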
