 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
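The import swap above pairs with the merged projection introduced below: `_ACTIVATION_REGISTRY` resolves an activation name to a plain module (e.g. SiLU), while `get_act_and_mul_fn` resolves it to the fused "act-and-mul" variant that consumes the concatenated `[gate, up]` output of a single merged projection. A minimal sketch of those semantics in plain PyTorch (not vLLM's kernel-backed implementation):

```python
import torch
import torch.nn.functional as F

def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # Input is the merged projection output, [gate, up] concatenated on
    # the last dim; output is act(gate) * up, here with SiLU.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up
```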
@@ -171,16 +172,12 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(in_features,
-                                              hidden_features,
-                                              bias=bias,
-                                              quant_config=quant_config,
-                                              prefix=f"{prefix}.gate_proj")
-        self.up_proj = ColumnParallelLinear(in_features,
-                                            hidden_features,
-                                            bias=bias,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.up_proj")
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(hidden_features,
                                            in_features,
                                            bias=bias,
@@ -189,10 +186,9 @@ def __init__(self,
         self.act_fn = act_fn

     def forward(self, x: torch.Tensor):
-        x_gate, _ = self.gate_proj(x)
-        x_gate = self.act_fn(x_gate)
-        x_up, _ = self.up_proj(x)
-        x_down, _ = self.down_proj(x_gate * x_up)
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x_down, _ = self.down_proj(x)
         return x_down

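The forward pass is functionally unchanged: one GEMM now produces both halves, and the fused activation applies act(gate) * up in a single op. A quick equivalence check, sketched with plain `nn.Linear` standing in for the tensor-parallel layers (all names here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_features, hidden_features = 8, 16
gate = nn.Linear(in_features, hidden_features, bias=False)
up = nn.Linear(in_features, hidden_features, bias=False)

# Merged weight: gate rows first, then up rows, matching the
# output_sizes=[hidden_features] * 2 ordering in the diff above.
merged = nn.Linear(in_features, 2 * hidden_features, bias=False)
with torch.no_grad():
    merged.weight.copy_(torch.cat([gate.weight, up.weight], dim=0))

x = torch.randn(2, in_features)
old = F.silu(gate(x)) * up(x)      # separate gate/up projections
g, u = merged(x).chunk(2, dim=-1)
new = F.silu(g) * u                # single merged projection
assert torch.allclose(old, new, atol=1e-6)
```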
@@ -540,14 +536,14 @@ def __init__(
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

         self.blocks = nn.ModuleList([
-            Qwen2_5_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}")
+            Qwen2_5_VisionBlock(dim=self.hidden_size,
+                                num_heads=self.num_heads,
+                                mlp_hidden_dim=vision_config.intermediate_size,
+                                act_fn=get_act_and_mul_fn(
+                                    vision_config.hidden_act),
+                                norm_layer=norm_layer,
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.blocks.{layer_idx}")
             for layer_idx in range(depth)
         ])
         self.merger = Qwen2_5_VisionPatchMerger(
@@ -752,6 +748,8 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("attn.qkv.", "attn.q.", "q"),
             ("attn.qkv.", "attn.k.", "k"),
             ("attn.qkv.", "attn.v.", "v"),
+            ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
+            ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: set[str] = set()
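The two new mapping entries tell the weight loader that checkpoint tensors named `mlp.gate_proj.*` and `mlp.up_proj.*` now live in shard 0 and shard 1 of the merged `mlp.gate_up_proj` parameter. A simplified sketch of how such a mapping is typically consumed in a `load_weights` loop (vLLM's real loop handles more cases, e.g. quantization scales):

```python
def load_stacked(weights, params_dict, stacked_params_mapping):
    # weights: iterable of (checkpoint_name, tensor) pairs;
    # params_dict: model parameter name -> parameter.
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # e.g. "...mlp.gate_proj.weight" -> "...mlp.gate_up_proj.weight"
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            # The merged layer's weight_loader copies the tensor into the
            # shard reserved for gate (0) or up (1).
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No stacked mapping matched: load the tensor one-to-one.
            param = params_dict[name]
            param.weight_loader(param, loaded_weight)
```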