
Commit de34258

[Model] Define merge_by_field_config MM interface (R-T) (#26260)
Signed-off-by: Ayush Satyam <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 185d8ed commit de34258
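
Note: the `merge_by_field_config` interface named in the title is a class attribute on multi-modal model classes. As the diffs below show, each model in this range (Step3-VL, Terratorch) sets it to True and in exchange drops its own flatten_bn() and .to(self.device) handling. A minimal sketch of the opt-in pattern, with MyMultiModalModel as an illustrative placeholder rather than a class touched by this commit:

import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsMultiModal


class MyMultiModalModel(nn.Module, SupportsMultiModal):
    # Opt into the merged-by-field-config input path. Judging by the diffs in
    # this commit, the multi-modal kwargs then arrive already flattened across
    # the batch and placed on the target device, so the model no longer calls
    # flatten_bn() or .to(self.device) in _parse_and_validate_image_input().
    merge_by_field_config = True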

File tree (3 files changed: 46 additions, 45 deletions)

    vllm/model_executor/models/step3_vl.py
    vllm/model_executor/models/tarsier.py
    vllm/model_executor/models/terratorch.py


vllm/model_executor/models/step3_vl.py (36 additions, 33 deletions)

@@ -4,7 +4,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from itertools import product
 from math import ceil, sqrt
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -44,28 +44,48 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
 from .vision import run_dp_sharded_vision_model
 
 
-class Step3VLImagePixelInputs(TypedDict):
+class Step3VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bnp: Batch size * number of images * number of patches
+        - hp: Height of patch
+        - wp: Width of patch
+    """
+
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    patch_pixel_values: Optional[torch.Tensor]
-    num_patches: list[int]
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
+    patch_pixel_values: Annotated[
+        Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
+    ]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
 
+class Step3VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
 
-class Step3VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
 
 Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
@@ -895,6 +915,8 @@ def forward(
     dummy_inputs=Step3VLDummyInputsBuilder,
 )
 class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.": "language_model.model.",
@@ -982,41 +1004,22 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            if pixel_values.dim() >= 3:
-                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
-            if patch_pixel_values is not None:
-                patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
-                patch_pixel_values = patch_pixel_values.view(
-                    -1, *patch_pixel_values.shape[-3:]
-                )
-                # Handle empty patch_pixel_values by setting to None
-                if patch_pixel_values.shape[0] == 0:
-                    patch_pixel_values = None
-            num_patches = flatten_bn(num_patches, concat=True).tolist()
-
             return Step3VLImagePixelInputs(
                 type="pixel_values",
-                pixel_values=pixel_values.to(self.dtype).to(self.device),
-                patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
+                pixel_values=pixel_values.to(self.dtype),
+                patch_pixel_values=patch_pixel_values.to(self.dtype)
                 if patch_pixel_values is not None
                 else None,
                 num_patches=num_patches,
            )
 
         if image_embeds is not None:
-            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
-                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
-            else:
-                raise ValueError(
-                    f"Unexpected shape for image_embeds: {image_embeds.shape}"
-                )
-
             return Step3VLImageEmbeddingInputs(
                 type="image_embeds",
-                image_embeds=image_embeds.to(self.dtype).to(self.device),
+                image_embeds=image_embeds.to(self.dtype),
             )
-        return None
+
+        raise AssertionError("This line should be unreachable.")
 
     def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
         B, P = image_features.shape[:2]
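
For context, a hedged sketch of how the new TensorSchema-based pixel inputs above would be constructed. It assumes (which this diff does not show) that TensorSchema checks each field against its TensorShape annotation when the object is built; the tensor sizes are arbitrary placeholders, not Step3-VL's actual image resolution:

import torch

from vllm.model_executor.models.step3_vl import Step3VLImagePixelInputs

# Two images (bn = 2), each split into 4 patches (bnp = 8). Sizes are
# illustrative only.
inputs = Step3VLImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(2, 3, 448, 448),        # ("bn", 3, "h", "w")
    patch_pixel_values=torch.zeros(8, 3, 224, 224),  # ("bnp", 3, "hp", "wp")
    num_patches=torch.tensor([4, 4]),                # ("bn",)
)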

vllm/model_executor/models/tarsier.py (1 addition, 5 deletions)

@@ -47,11 +47,7 @@
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (
-    AutoWeightsLoader,
-    init_vllm_registered_model,
-    maybe_prefix,
-)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 from .vision import (
     VisionEncoderInfo,
     get_num_selected_vision_tokens,

vllm/model_executor/models/terratorch.py (9 additions, 7 deletions)

@@ -87,12 +87,10 @@ def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
            if input.type == InputTypeEnum.tensor:
                fields[input_name] = "image"
 
-        mm_fields_config = {}
-        for field_name, field_modality in fields.items():
-            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
-                batch_size=1, modality=field_modality
-            )
-        return mm_fields_config
+        return {
+            field_name: MultiModalFieldConfig.batched(modality=field_modality)
+            for field_name, field_modality in fields.items()
+        }
 
    return _terratorch_field_config
 
@@ -192,9 +190,12 @@ def apply(
    ) -> MultiModalInputs:
        if "image" in mm_data:
            image_data = mm_data["image"]
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
        else:
            image_data = mm_data
-            mm_data = {"image": mm_data}
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
+
+        mm_data = {"image": image_data}
 
        mm_items = self._to_mm_items(mm_data)
        tokenization_kwargs = tokenization_kwargs or {}
@@ -226,6 +227,7 @@ def apply(
    dummy_inputs=TerratorchInputBuilder,
 )
 class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
+    merge_by_field_config = True
    supports_multimodal_raw_input_only = True
    is_pooling_model = True
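
As a side note, a small standalone sketch of the reshaping the updated Terratorch apply() performs before wrapping everything as mm_data = {"image": image_data}: each per-request tensor gains a leading batch dimension of size 1, which is what lines it up with MultiModalFieldConfig.batched(). The field names and sizes below are illustrative placeholders, not taken from Terratorch:

import torch

# Per-request raw inputs, one sample each (no batch dimension yet).
image_data = {
    "pixels": torch.zeros(6, 512, 512),
    "location_coords": torch.zeros(2),
}

# Same comprehension as in the diff: prepend a batch dimension of size 1.
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}

assert image_data["pixels"].shape == (1, 6, 512, 512)
assert image_data["location_coords"].shape == (1, 2)

mm_data = {"image": image_data}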
