 except ImportError:
     USE_XFORMERS_OPS = False

+PATCH_MERGE = "patch_merge"
+

 class PixtralImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -155,7 +157,6 @@ def __call__(

         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
-
             image_processed = torch.tensor(image_inputs.image)
             image_tokens = torch.tensor(image_inputs.tokens)

@@ -353,6 +354,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )

         self.vision_encoder = VisionTransformer(self.vision_args)
+
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
+                                                 eps=1e-5)
+
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            self.patch_merger = PatchMerger(
+                vision_encoder_dim=self.vision_args.hidden_size,
+                spatial_merge_size=self.vision_args.spatial_merge_size,
+                use_mlp_bias=False,
+            )
         self.vision_language_adapter = VisionLanguageAdapter(
             self.vision_args, dim=config.text_config.hidden_size)

@@ -398,13 +420,25 @@ def _process_image_input(
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
         images = image_input["images"]
-
         image_features = self.vision_encoder(images)
         feature_sizes = [
             image_feature.shape[0] for image_feature in image_features
         ]
-
-        image_embeds = self.vision_language_adapter(torch.cat(image_features))
+        image_features = torch.cat(image_features)
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            image_features = self.pre_mm_projector_norm(image_features)
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            patch_size = self.vision_args.patch_size
+            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
+            img_patch_dims = [(img.shape[1] // patch_size,
+                               img.shape[2] // patch_size) for img in images]
+            feature_sizes = [
+                feature_size // spatial_merge_size_square
+                for feature_size in feature_sizes
+            ]
+            image_features = self.patch_merger(image_features,
+                                               image_sizes=img_patch_dims)
+        image_embeds = self.vision_language_adapter(image_features)
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds

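For intuition on the bookkeeping above: each image contributes `(H / patch_size) * (W / patch_size)` encoder tokens, and the merger shrinks that count by `spatial_merge_size ** 2`, so `feature_sizes` must shrink by the same factor before `torch.split`. A quick arithmetic check (numbers are illustrative, not from this PR):

```python
# Illustrative numbers: one 512x512 image.
patch_size = 16
spatial_merge_size = 2
h_tokens, w_tokens = 512 // patch_size, 512 // patch_size  # 32 x 32 grid
feature_size = h_tokens * w_tokens                         # 1024 encoder tokens
merged_size = feature_size // spatial_merge_size**2        # 256 tokens after merging
print(feature_size, merged_size)                           # 1024 256
```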
@@ -524,8 +558,19 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
         def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_language_adapter")

+        def is_patch_merger(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("patch_merger")
+
+        def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("pre_mm_projector_norm")
+
         # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        patch_merger_dict = dict(
+            self.patch_merger.named_parameters()
+        ) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
+        pre_mm_projector_norm_dict = dict(
+            self.pre_mm_projector_norm.named_parameters()
+        ) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
         vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
@@ -538,6 +583,18 @@ def llm_weights_generator():
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
+                elif is_patch_merger((name, w)):
+                    # Load vision patch merger weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = patch_merger_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                elif is_pre_mm_projector_norm((name, w)):
+                    # Load vision pre_mm_projector_norm weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = pre_mm_projector_norm_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
                     # Load vision-language adapter weights directly
                     trimmed_name = '.'.join(name.split(".")[1:])
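The new `elif` branches extend the loader's existing prefix dispatch: each checkpoint tensor is routed by the first component of its dotted name, which is then trimmed before the parameter-dict lookup. A standalone sketch of that pattern (the helper and tensor names here are hypothetical, for illustration only):

```python
# Hypothetical sketch of the prefix-dispatch idea; not the actual loader.
def route_weight(name: str) -> tuple[str, str]:
    """Return (owning submodule, name with the prefix trimmed)."""
    for prefix in ("vision_encoder", "patch_merger",
                   "pre_mm_projector_norm", "vision_language_adapter"):
        if name.startswith(prefix):
            return prefix, '.'.join(name.split(".")[1:])
    return "language_model", name

assert route_weight("patch_merger.merging_layer.weight") == \
    ("patch_merger", "merging_layer.weight")
```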
@@ -566,6 +623,9 @@ class VisionEncoderArgs:
     rope_theta: float  # for rope-2D
     image_token_id: int
     adapter_bias: bool = True
+    spatial_merge_size: int = 1
+    add_pre_mm_projector_layer_norm: bool = False
+    mm_projector_id: str = ""

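All three new fields default to the pre-existing behavior (`spatial_merge_size=1`, no extra norm, empty projector id), so older checkpoints load unchanged. A config that opts in might look like the following; only the last three fields come from this diff, while the remaining field names and all values are assumptions for illustration:

```python
# Values illustrative; only the last three fields are new in this PR.
vision_args = VisionEncoderArgs(
    hidden_size=1024,
    num_channels=3,
    image_size=1024,
    patch_size=16,
    intermediate_size=4096,
    num_hidden_layers=24,
    num_attention_heads=16,
    rope_theta=10000.0,
    image_token_id=10,
    spatial_merge_size=2,                  # fold each 2x2 patch block into one token
    add_pre_mm_projector_layer_norm=True,  # apply RMSNorm before the projector
    mm_projector_id=PATCH_MERGE,           # select the PatchMerger projector
)
```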

 def _reshape_for_broadcast(freqs_cis: torch.Tensor,
@@ -843,6 +903,104 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))


+class PatchMerger(nn.Module):
+    """
+    Learned merging of spatial_merge_size ** 2 patches
+    """
+
+    def __init__(
+        self,
+        vision_encoder_dim: int,
+        spatial_merge_size: int,
+        use_mlp_bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
+
+        self.spatial_merge_size = spatial_merge_size
+        self.mlp_input_dim = mlp_input_dim
+
+        self.merging_layer = nn.Linear(
+            mlp_input_dim,
+            vision_encoder_dim,
+            bias=use_mlp_bias,
+        )
+
+    def forward(self, x: torch.Tensor,
+                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
+        # image_sizes specified in tokens
+        assert sum([h * w for h, w in image_sizes]) == len(x)
+
+        # x is (N, vision_encoder_dim)
+        x = self.permute(x, image_sizes)
+
+        # x is (N / spatial_merge_size ** 2,
+        #       vision_encoder_dim * spatial_merge_size ** 2)
+        x = self.merging_layer(x)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
+        return x
+
+    def permute(
+        self,
+        x: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: (N, D) where N is flattened and concatenated patch tokens
+                for all images
+            image_sizes: list of tuple of (height, width) in tokens for
+                each image
+        Returns:
+            image_features: reorders patch tokens so each grid of
+                (spatial_merge_size, spatial_merge_size) is contiguous.
+                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
+        """
+        sub_grids = get_sub_grids(
+            x=x,
+            image_sizes=image_sizes,
+            spatial_merge_size=self.spatial_merge_size
+        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
+        permuted_tensor: list[torch.Tensor] = []
+        for grid in sub_grids:
+            n_patches = grid.shape[-1]
+            permuted_tensor.append(grid.view(-1, n_patches).t(
+            ))  # n_patches x d * sub_grid_size * sub_grid_size
+        return torch.cat(
+            permuted_tensor,
+            dim=0)  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
+
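A quick shape check of the merger on a toy input (dimensions are illustrative; assumes the `PatchMerger` class above is importable):

```python
import torch

merger = PatchMerger(vision_encoder_dim=8, spatial_merge_size=2)
x = torch.randn(4 * 6, 8)              # one 4x6 token grid, D=8
out = merger(x, image_sizes=[(4, 6)])
print(out.shape)                       # torch.Size([6, 8]): 24 / 2**2 merged tokens
```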
+
+def get_sub_grids(
+    x: torch.Tensor,
+    image_sizes: list[tuple[int, int]],
+    spatial_merge_size: int,
+) -> list[torch.Tensor]:
+    # image_sizes specified in tokens
+    tokens_per_image = [h * w for h, w in image_sizes]
+    d = x.shape[-1]
+    all_img_sub_grids: list[torch.Tensor] = []
+    sub_grid_size = spatial_merge_size
+
+    for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
+        # Reshape image_tokens into a 2D grid
+        h, w = image_sizes[image_index]
+        image_grid = image_tokens.view(h, w, d).permute(
+            2, 0, 1)[None, :, :, :]  # 1 x d x h x w
+        sub_grids = torch.nn.functional.unfold(image_grid,
+                                               kernel_size=sub_grid_size,
+                                               stride=sub_grid_size)
+        sub_grids = sub_grids.view(
+            1, d, sub_grid_size, sub_grid_size,
+            -1)  # 1 x d x sub_grid_size x sub_grid_size x n_patches
+
+        all_img_sub_grids.append(sub_grids[0])
+
+    return all_img_sub_grids
+
+
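`unfold` with `kernel_size == stride == spatial_merge_size` carves each `d x h x w` grid into non-overlapping sub-grids, one column per sub-grid. A tiny shape check (dimensions illustrative):

```python
import torch

x = torch.randn(4 * 6, 2)  # 24 tokens of dim d=2 for a single 4x6 image
grids = get_sub_grids(x, image_sizes=[(4, 6)], spatial_merge_size=2)
print(grids[0].shape)      # torch.Size([2, 2, 2, 6]): d x 2 x 2 x (24 / 4)
```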
 #### HF Transformers version of Pixtral ####
 # Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
 # This model follows the Llava family, meaning image embeddings are placed