
Commit d20b0c1

Add patch merger (#14957)
1 parent 166a168 commit d20b0c1

File tree (4 files changed: +166 −8 lines)

  requirements/common.txt
  requirements/docs.txt
  requirements/test.in
  vllm/model_executor/models/pixtral.py

requirements/common.txt
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ pyzmq
 msgspec
 gguf == 0.10.0
 importlib_metadata
-mistral_common[opencv] >= 1.5.0
+mistral_common[opencv] >= 1.5.4
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12

requirements/docs.txt
Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
-mistral_common >= 1.5.0
+mistral_common >= 1.5.4
 aiohttp
 starlette
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args

requirements/test.in
Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ torchaudio==2.6.0
 torchvision==0.21.0
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.5.0 # required for pixtral test
+mistral_common[opencv] >= 1.5.4 # required for pixtral test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.4 # required for model evaluation test
 transformers==4.48.2
@@ -40,4 +40,4 @@ tritonclient==2.51.0
 
 numpy < 2.0.0
 runai-model-streamer==0.11.0
-runai-model-streamer-s3==0.11.0
+runai-model-streamer-s3==0.11.0
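
The requirements changes above mainly raise the mistral_common floor from 1.5.0 to 1.5.4, presumably the release needed by the Pixtral patch-merger support added below. As a minimal sketch (assuming a standard Python >= 3.8 environment, nothing vLLM-specific), a local install can be checked against the new floor with importlib.metadata:

    # Minimal sketch: check the installed mistral_common against the new >= 1.5.4 floor.
    from importlib.metadata import PackageNotFoundError, version

    try:
        installed = version("mistral_common")
    except PackageNotFoundError:
        installed = None

    print("mistral_common:", installed)
    # Plain tuple comparison is enough for simple dotted releases; use
    # packaging.version.Version if you need full PEP 440 handling.
    assert installed and tuple(int(p) for p in installed.split(".")[:3]) >= (1, 5, 4)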

vllm/model_executor/models/pixtral.py
Lines changed: 162 additions & 4 deletions

@@ -56,6 +56,8 @@
 except ImportError:
     USE_XFORMERS_OPS = False
 
+PATCH_MERGE = "patch_merge"
+
 
 class PixtralImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -155,7 +157,6 @@ def __call__(
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
-
             image_processed = torch.tensor(image_inputs.image)
             image_tokens = torch.tensor(image_inputs.tokens)
 

@@ -353,6 +354,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
 
         self.vision_encoder = VisionTransformer(self.vision_args)
+
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
+                                                 eps=1e-5)
+
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            self.patch_merger = PatchMerger(
+                vision_encoder_dim=self.vision_args.hidden_size,
+                spatial_merge_size=self.vision_args.spatial_merge_size,
+                use_mlp_bias=False,
+            )
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
+                                                 eps=1e-5)
+
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            self.patch_merger = PatchMerger(
+                vision_encoder_dim=self.vision_args.hidden_size,
+                spatial_merge_size=self.vision_args.spatial_merge_size,
+                use_mlp_bias=False,
+            )
         self.vision_language_adapter = VisionLanguageAdapter(
             self.vision_args, dim=config.text_config.hidden_size)
 

@@ -398,13 +420,25 @@ def _process_image_input(
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
         images = image_input["images"]
-
         image_features = self.vision_encoder(images)
         feature_sizes = [
             image_feature.shape[0] for image_feature in image_features
         ]
-
-        image_embeds = self.vision_language_adapter(torch.cat(image_features))
+        image_features = torch.cat(image_features)
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            image_features = self.pre_mm_projector_norm(image_features)
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            patch_size = self.vision_args.patch_size
+            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
+            img_patch_dims = [(img.shape[1] // patch_size,
+                               img.shape[2] // patch_size) for img in images]
+            feature_sizes = [
+                feature_size // spatial_merge_size_square
+                for feature_size in feature_sizes
+            ]
+            image_features = self.patch_merger(image_features,
+                                               image_sizes=img_patch_dims)
+        image_embeds = self.vision_language_adapter(image_features)
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
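A worked shape check for the hunk above, as a hedged sketch: each image's patch grid comes from its pixel dimensions divided by patch_size, and its feature count shrinks by spatial_merge_size ** 2 after merging. The concrete numbers (patch_size=16, spatial_merge_size=2, a 512x768 image) are illustrative, not taken from any real checkpoint:

    # Illustrative arithmetic only; patch_size=16 and spatial_merge_size=2 are assumptions.
    patch_size = 16
    spatial_merge_size = 2

    # A (C, 512, 768) image tensor yields a 32 x 48 grid of patch tokens.
    h_tokens, w_tokens = 512 // patch_size, 768 // patch_size    # (32, 48)
    feature_size = h_tokens * w_tokens                           # 1536 tokens

    # After the patch merger, every 2 x 2 block of tokens becomes one row,
    # matching feature_size // spatial_merge_size_square in the diff above.
    merged_feature_size = feature_size // spatial_merge_size**2  # 384 tokens
    print((h_tokens, w_tokens), feature_size, merged_feature_size)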

@@ -524,8 +558,19 @@ def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
         def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_language_adapter")
 
+        def is_patch_merger(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("patch_merger")
+
+        def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("pre_mm_projector_norm")
+
         # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        patch_merger_dict = dict(self.patch_merger.named_parameters(
+        )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
+        pre_mm_projector_norm_dict = dict(
+            self.pre_mm_projector_norm.named_parameters(
+            )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
         vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
 

@@ -538,6 +583,18 @@ def llm_weights_generator():
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
+                elif is_patch_merger((name, w)):
+                    # Load vision patch merger weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = patch_merger_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                elif is_pre_mm_projector_norm((name, w)):
+                    # Load vision pre_mm_projector_norm weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = pre_mm_projector_norm_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
                     # Load vision-language adapter weights directly
                     trimmed_name = '.'.join(name.split(".")[1:])
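
The loader above routes each checkpoint tensor by its top-level prefix and strips that prefix before looking the tensor up in the corresponding parameter dict. A minimal sketch of just that name handling, with a hypothetical weight name (real names depend on the checkpoint):

    # Name-routing sketch; "patch_merger.merging_layer.weight" is a hypothetical example.
    name = "patch_merger.merging_layer.weight"

    routed_to_patch_merger = name.startswith("patch_merger")   # True
    trimmed_name = '.'.join(name.split(".")[1:])               # "merging_layer.weight"

    # trimmed_name is the key into dict(self.patch_merger.named_parameters()),
    # mirroring patch_merger_dict[trimmed_name] in the loader above.
    print(routed_to_patch_merger, trimmed_name)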
@@ -566,6 +623,9 @@ class VisionEncoderArgs:
     rope_theta: float  # for rope-2D
     image_token_id: int
     adapter_bias: bool = True
+    spatial_merge_size: int = 1
+    add_pre_mm_projector_layer_norm: bool = False
+    mm_projector_id: str = ""
 
 
 def _reshape_for_broadcast(freqs_cis: torch.Tensor,
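
The three new fields default to the pre-existing behaviour (no merging, no extra norm, empty projector id), so older checkpoints are unaffected; only a config that sets them switches on the new paths in __init__ and _process_image_input. A minimal sketch of that gating, using a stand-in dataclass for just the new fields (the other VisionEncoderArgs fields are omitted here):

    # Hedged sketch of the gating the new fields drive; _VisionArgsSubset is a
    # stand-in for illustration, not the real VisionEncoderArgs.
    from dataclasses import dataclass

    PATCH_MERGE = "patch_merge"

    @dataclass
    class _VisionArgsSubset:
        spatial_merge_size: int = 1
        add_pre_mm_projector_layer_norm: bool = False
        mm_projector_id: str = ""

    legacy = _VisionArgsSubset()  # defaults: neither new module is built
    merged = _VisionArgsSubset(spatial_merge_size=2,
                               add_pre_mm_projector_layer_norm=True,
                               mm_projector_id=PATCH_MERGE)

    for args in (legacy, merged):
        builds_patch_merger = args.mm_projector_id == PATCH_MERGE
        builds_pre_norm = args.add_pre_mm_projector_layer_norm
        print(builds_patch_merger, builds_pre_norm, args.spatial_merge_size)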
@@ -843,6 +903,104 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))
 
 
+class PatchMerger(nn.Module):
+    """
+    Learned merging of spatial_merge_size ** 2 patches
+    """
+
+    def __init__(
+        self,
+        vision_encoder_dim: int,
+        spatial_merge_size: int,
+        use_mlp_bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
+
+        self.spatial_merge_size = spatial_merge_size
+        self.mlp_input_dim = mlp_input_dim
+
+        self.merging_layer = nn.Linear(
+            mlp_input_dim,
+            vision_encoder_dim,
+            bias=use_mlp_bias,
+        )
+
+    def forward(self, x: torch.Tensor,
+                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
+        # image_sizes specified in tokens
+        assert sum([h * w for h, w in image_sizes]) == len(x)
+
+        # x is (N, vision_encoder_dim)
+        x = self.permute(x, image_sizes)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim * spatial_merge_size ** 2)
+        x = self.merging_layer(x)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
+        return x
+
+    def permute(
+        self,
+        x: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: (N, D) where N is flattened and concatenated patch tokens
+                for all images
+            image_sizes: list of tuple of (height, width) in tokens for
+                each image
+        Returns:
+            image_features: reorders patch tokens so each grid of
+                (spatial_merge_size, spatial_merge_size) is contiguous.
+                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
+        """
+
+        sub_grids = get_sub_grids(
+            x=x,
+            image_sizes=image_sizes,
+            spatial_merge_size=self.spatial_merge_size
+        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
+        permuted_tensor: list[torch.Tensor] = []
+        for grid in sub_grids:
+            n_patches = grid.shape[-1]
+            permuted_tensor.append(grid.view(-1, n_patches).t(
+            ))  # n_patches x d * sub_grid_size * sub_grid_size
+        return torch.cat(
+            permuted_tensor, dim=0
+        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
+
+
+def get_sub_grids(
+    x: torch.Tensor,
+    image_sizes: list[tuple[int, int]],
+    spatial_merge_size: int,
+) -> list[torch.Tensor]:
+    # image_sizes specified in tokens
+    tokens_per_image = [h * w for h, w in image_sizes]
+    d = x.shape[-1]
+    all_img_sub_grids: list[torch.Tensor] = []
+    sub_grid_size = spatial_merge_size
+
+    for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
+        # Reshape image_tokens into a 2D grid
+        h, w = image_sizes[image_index]
+        image_grid = image_tokens.view(h, w, d).permute(
+            2, 0, 1)[None, :, :, :]  # 1 x d x h x w
+        sub_grids = torch.nn.functional.unfold(image_grid,
+                                               kernel_size=sub_grid_size,
+                                               stride=sub_grid_size)
+        sub_grids = sub_grids.view(
+            1, d, sub_grid_size, sub_grid_size,
+            -1)  # 1 x d x sub_grid_size x sub_grid_size x n_patches
+
+        all_img_sub_grids.append(sub_grids[0])
+
+    return all_img_sub_grids
+
+
 #### HF Transformers version of Pixtral ####
 # Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
 # This model follows the Llava family, meaning image embeddings are placed
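
A short usage sketch of the PatchMerger added above, checking shapes only: weights are random, the toy dimensions and token grids are made up, and the import assumes this commit's vllm.model_executor.models.pixtral is importable in your environment (otherwise paste the two definitions above into a standalone script):

    # Shape check for PatchMerger; toy sizes, random weights.
    import torch
    from vllm.model_executor.models.pixtral import PatchMerger  # added in this commit

    d = 64                                          # toy vision_encoder_dim
    image_sizes = [(8, 12), (4, 4)]                 # (height, width) in tokens per image
    n_tokens = sum(h * w for h, w in image_sizes)   # 96 + 16 = 112

    merger = PatchMerger(vision_encoder_dim=d, spatial_merge_size=2,
                         use_mlp_bias=False)
    out = merger(torch.randn(n_tokens, d), image_sizes=image_sizes)

    # Every 2 x 2 block of tokens collapses into one row: 112 / 4 = 28.
    print(out.shape)                                # torch.Size([28, 64])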
