@@ -115,6 +115,11 @@ def __init__(
         else:
             self.gate_proj = None

+        if merge and ap.mlp_patch_merger:
+            self.patch_merger_proj = ExLlamaV2Linear(model, key + km["patch_merger"], in_features * merge, in_features, ap.mlp_bias)
+            self.submodules += [self.patch_merger_proj]
+        else:
+            self.patch_merger_proj = None

     def numel(self) -> int:

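Note on the shape contract: the existing merge path reshapes l patches into l // merge tokens of d * merge features, so `merge` counts the total number of patches folded into one token (e.g. merge = 4 for a 2x2 window), and the merger projects in_features * merge concatenated features back down to in_features, as in Pixtral-style patch mergers. A minimal standalone sketch with hypothetical sizes, using a plain nn.Linear to stand in for ExLlamaV2Linear:

    import torch
    import torch.nn as nn

    in_features, merge = 1024, 4                        # hypothetical hidden size, 2x2 window
    patch_merger_proj = nn.Linear(in_features * merge, in_features, bias = False)

    windows = torch.randn(1, 16, in_features * merge)   # 16 windows of 4 concatenated patches
    merged = patch_merger_proj(windows)
    assert merged.shape == (1, 16, in_features)
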
@@ -158,6 +163,9 @@ def load(
         if self.gate_proj is not None: self.gate_proj.load(device_context = device_context, output_map = down_map)
         self.up_proj.load(device_context = device_context, output_map = down_map)

+        if self.patch_merger_proj is not None:
+            self.patch_merger_proj.load()
+
         if self.up_proj.is_quant():
             assert self.gate_proj is None or self.gate_proj.is_quant()
             assert self.down_proj.is_quant(), "Partially quantized MLP layer"
@@ -302,6 +310,8 @@ def set_device_idx(self, idx: int | None):
         if self.gate_proj is not None: self.gate_proj.set_device_idx(idx)
         self.up_proj.set_device_idx(idx)
         self.down_proj.set_device_idx(idx)
+        if self.patch_merger_proj is not None:
+            self.patch_merger_proj.set_device_idx(idx)


     # @profile
@@ -458,9 +468,18 @@ def forward_torch(
             if self.pre_layernorm else hidden_states

         if self.merge:
-            bd = post_norm.shape[:-2]
-            l, d = post_norm.shape[-2:]
-            post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)
+            if self.archparams.mlp_patch_merger:
+                bsz = hidden_states.shape[0]
+                assert bsz == 1, "patch merger requires batch size 1"
+                (h, w), d = kwargs["patch_size"], hidden_states.shape[-1]
+                image_grid = post_norm.view(h, w, d).permute(2, 0, 1).unsqueeze(0)  # (1, d, h, w)
+                grid = F.unfold(image_grid, kernel_size = int(self.merge ** 0.5), stride = int(self.merge ** 0.5))  # k x k neighborhoods, k = sqrt(merge)
+                grid = grid.view(bsz, d * self.merge, -1).transpose(1, 2)  # (1, num_windows, d * merge)
+                post_norm = self.patch_merger_proj.forward(grid)
+            else:
+                bd = post_norm.shape[:-2]
+                l, d = post_norm.shape[-2:]
+                post_norm = post_norm.view(*bd, l // self.merge, d * self.merge)

         if self.gate_proj is not None:
             gate = self.gate_proj.forward(post_norm, loras = loras)
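The unfold path changes which patches get merged, not just how: F.unfold gathers patches that are neighbors in the 2D image grid, whereas the plain view in the fallback branch merges patches that are merely consecutive in the flattened sequence. A standalone sketch (hypothetical 4x4 grid, merge = 4) mirroring the tensor ops above:

    import torch
    import torch.nn.functional as F

    h, w, d, merge = 4, 4, 8, 4
    k = int(merge ** 0.5)                                # 2x2 windows

    post_norm = torch.randn(h * w, d)                    # flat row-major patch sequence
    image_grid = post_norm.view(h, w, d).permute(2, 0, 1).unsqueeze(0)  # (1, d, h, w)
    grid = F.unfold(image_grid, kernel_size = k, stride = k)            # (1, d * merge, 4)
    grid = grid.view(1, d * merge, -1).transpose(1, 2)                  # (1, 4, d * merge)

    # The first window covers grid cells (0,0), (0,1), (1,0), (1,1), i.e. flat
    # sequence positions 0, 1, 4 and 5 -- not 0..3 as the sequential view gives.
    # unfold lays features out channel-major: (d, merge) per window.
    assert torch.equal(grid[0, 0].view(d, merge), post_norm[[0, 1, 4, 5]].t())

The channel-major layout matters because the merger's pretrained weights expect exactly this ordering of the d * merge input features; HF's Pixtral/Mistral3 patch merger uses the same unfold-based extraction.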