 except ImportError:
     USE_XFORMERS_OPS = False

+PATCH_MERGE = "patch_merge"
+

 class PixtralImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -155,7 +157,6 @@ def __call__(

         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
-
             image_processed = torch.tensor(image_inputs.image)
             image_tokens = torch.tensor(image_inputs.tokens)

@@ -353,6 +354,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )

         self.vision_encoder = VisionTransformer(self.vision_args)
+
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
+                                                 eps=1e-5)
+
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            self.patch_merger = PatchMerger(
+                vision_encoder_dim=self.vision_args.hidden_size,
+                spatial_merge_size=self.vision_args.spatial_merge_size,
+                use_mlp_bias=False,
+            )
         self.vision_language_adapter = VisionLanguageAdapter(
             self.vision_args, dim=config.text_config.hidden_size)

@@ -398,13 +420,25 @@ def _process_image_input(
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
         images = image_input["images"]
-
         image_features = self.vision_encoder(images)
         feature_sizes = [
             image_feature.shape[0] for image_feature in image_features
         ]
-
-        image_embeds = self.vision_language_adapter(torch.cat(image_features))
+        image_features = torch.cat(image_features)
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            image_features = self.pre_mm_projector_norm(image_features)
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            patch_size = self.vision_args.patch_size
+            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
+            img_patch_dims = [(img.shape[1] // patch_size,
+                               img.shape[2] // patch_size) for img in images]
+            feature_sizes = [
+                feature_size // spatial_merge_size_square
+                for feature_size in feature_sizes
+            ]
+            image_features = self.patch_merger(image_features,
+                                               image_sizes=img_patch_dims)
+        image_embeds = self.vision_language_adapter(image_features)
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds

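A quick sanity check of the token arithmetic above (a sketch only; the 1024x768 image, patch_size=16 and spatial_merge_size=2 below are illustrative values, not taken from this diff):

# Illustrative values only: a hypothetical 1024x768 image.
patch_size = 16
spatial_merge_size = 2
height_px, width_px = 1024, 768

# The vision encoder emits one feature per patch.
tokens_h, tokens_w = height_px // patch_size, width_px // patch_size  # 64, 48
n_tokens = tokens_h * tokens_w  # 3072

# After the patch merger, every spatial_merge_size x spatial_merge_size block
# of patch tokens collapses into one token, so each entry of feature_sizes
# shrinks by spatial_merge_size ** 2, matching the division performed above.
merged = n_tokens // spatial_merge_size**2  # 768
assert merged == (tokens_h // spatial_merge_size) * (tokens_w // spatial_merge_size)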
@@ -524,8 +558,19 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
         def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_language_adapter")

+        def is_patch_merger(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("patch_merger")
+
+        def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("pre_mm_projector_norm")
+
         # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        patch_merger_dict = dict(self.patch_merger.named_parameters(
+        )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
+        pre_mm_projector_norm_dict = dict(
+            self.pre_mm_projector_norm.named_parameters(
+            )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
         vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())

@@ -538,6 +583,18 @@ def llm_weights_generator():
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
+                elif is_patch_merger((name, w)):
+                    # Load vision patch merger weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = patch_merger_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                elif is_pre_mm_projector_norm((name, w)):
+                    # Load vision pre_mm_projector_norm weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = pre_mm_projector_norm_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
                     # Load vision-language adapter weights directly
                     trimmed_name = '.'.join(name.split(".")[1:])
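The branches above route checkpoint tensors purely by the first dot-separated component of the weight name; a minimal sketch of the trimming (the key below is hypothetical, shaped like the names the new predicates match):

# Hypothetical checkpoint key matched by is_patch_merger.
name = "patch_merger.merging_layer.weight"
assert name.startswith("patch_merger")
trimmed_name = '.'.join(name.split(".")[1:])  # "merging_layer.weight"
# patch_merger_dict was built from self.patch_merger.named_parameters(),
# so it is keyed by this module-relative name and the lookup succeeds.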
@@ -566,6 +623,9 @@ class VisionEncoderArgs:
     rope_theta: float  # for rope-2D
     image_token_id: int
     adapter_bias: bool = True
+    spatial_merge_size: int = 1
+    add_pre_mm_projector_layer_norm: bool = False
+    mm_projector_id: str = ""


 def _reshape_for_broadcast(freqs_cis: torch.Tensor,
@@ -843,6 +903,104 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))


+class PatchMerger(nn.Module):
+    """
+    Learned merging of spatial_merge_size ** 2 patches
+    """
+
+    def __init__(
+        self,
+        vision_encoder_dim: int,
+        spatial_merge_size: int,
+        use_mlp_bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
+
+        self.spatial_merge_size = spatial_merge_size
+        self.mlp_input_dim = mlp_input_dim
+
+        self.merging_layer = nn.Linear(
+            mlp_input_dim,
+            vision_encoder_dim,
+            bias=use_mlp_bias,
+        )
+
+    def forward(self, x: torch.Tensor,
+                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
+        # image_sizes specified in tokens
+        assert sum([h * w for h, w in image_sizes]) == len(x)
+
+        # x is (N, vision_encoder_dim)
+        x = self.permute(x, image_sizes)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim * spatial_merge_size ** 2)
+        x = self.merging_layer(x)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
+        return x
+
+    def permute(
+        self,
+        x: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: (N, D) where N is flattened and concatenated patch tokens
+                for all images
+            image_sizes: list of tuple of (height, width) in tokens for
+                each image
+        Returns:
+            image_features: reorders patch tokens so each grid of
+                (spatial_merge_size, spatial_merge_size) is contiguous.
+                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
+        """

+        sub_grids = get_sub_grids(
+            x=x,
+            image_sizes=image_sizes,
+            spatial_merge_size=self.spatial_merge_size
+        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
+        permuted_tensor: list[torch.Tensor] = []
+        for grid in sub_grids:
+            n_patches = grid.shape[-1]
+            permuted_tensor.append(grid.view(-1, n_patches).t(
+            ))  # n_patches x d * sub_grid_size * sub_grid_size
+        return torch.cat(
+            permuted_tensor, dim=0
+        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
+
+
+def get_sub_grids(
+    x: torch.Tensor,
+    image_sizes: list[tuple[int, int]],
+    spatial_merge_size: int,
+) -> list[torch.Tensor]:
+    # image_sizes specified in tokens
+    tokens_per_image = [h * w for h, w in image_sizes]
+    d = x.shape[-1]
+    all_img_sub_grids: list[torch.Tensor] = []
+    sub_grid_size = spatial_merge_size
+
+    for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
+        # Reshape image_tokens into a 2D grid
+        h, w = image_sizes[image_index]
+        image_grid = image_tokens.view(h, w, d).permute(
+            2, 0, 1)[None, :, :, :]  # 1 x d x h x w
+        sub_grids = torch.nn.functional.unfold(image_grid,
+                                               kernel_size=sub_grid_size,
+                                               stride=sub_grid_size)
+        sub_grids = sub_grids.view(
+            1, d, sub_grid_size, sub_grid_size,
+            -1)  # 1 x d x sub_grid_size x sub_grid_size x n_patches
+
+        all_img_sub_grids.append(sub_grids[0])
+
+    return all_img_sub_grids
+
+
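To make the reordering concrete, here is a small self-contained usage sketch (it assumes the PatchMerger and get_sub_grids definitions above are in scope; the toy sizes d=8, a 4x6 token grid and spatial_merge_size=2 are made up for illustration):

import torch

d, h, w, merge = 8, 4, 6, 2            # hidden dim; token-grid height/width
x = torch.randn(h * w, d)              # flattened patch tokens of one image

merger = PatchMerger(vision_encoder_dim=d, spatial_merge_size=merge)

# permute gathers each merge x merge block into one row of width d * merge**2.
packed = merger.permute(x, image_sizes=[(h, w)])
assert packed.shape == (h * w // merge**2, d * merge**2)  # (6, 32)

# The full forward additionally applies the merging linear layer back to d.
out = merger(x, image_sizes=[(h, w)])
assert out.shape == (h * w // merge**2, d)  # (6, 8)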
 #### HF Transformers version of Pixtral ####
 # Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
 # This model follows the Llava family, meaning image embeddings are placed