Commit ffa00c6

[Model] Use merge_by_field_config for MM models (R-T)
Signed-off-by: Ayush Satyam <[email protected]>
1 parent 512b8af

4 files changed: +22 -24 lines

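Each model replaces manual `flatten_bn` calls with the `merge_by_field_config = True` class attribute, under which the multimodal processor merges batched kwargs by their field config before they reach the model. For reference, a simplified sketch of what the removed helper did; `flatten_bn_sketch` is an illustrative re-implementation, not the vLLM source:

```python
import torch


def flatten_bn_sketch(
    x: torch.Tensor | list[torch.Tensor], *, concat: bool = False
) -> torch.Tensor | list[torch.Tensor]:
    """Simplified sketch of the removed flatten_bn(..., concat=True) call:
    collapse the leading (batch, num_items) structure into one flat dim."""
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)  # (B, N, ...) -> (B * N, ...)
    if concat:
        return torch.cat(x)  # list of (N_i, ...) -> (sum(N_i), ...)
    return x


# With merge_by_field_config = True this merging happens upstream, so the
# models below only keep a light isinstance(..., list) guard plus torch.cat.
```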

vllm/model_executor/models/step3_vl.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -49,7 +49,6 @@
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -895,6 +894,8 @@ def forward(
     dummy_inputs=Step3VLDummyInputsBuilder,
 )
 class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.": "language_model.model.",
@@ -982,18 +983,25 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
+            if isinstance(pixel_values, list):
+                pixel_values = torch.cat(pixel_values)
             if pixel_values.dim() >= 3:
                 pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
             if patch_pixel_values is not None:
-                patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
+                if isinstance(patch_pixel_values, list):
+                    patch_pixel_values = torch.cat(patch_pixel_values)
                 patch_pixel_values = patch_pixel_values.view(
                     -1, *patch_pixel_values.shape[-3:]
                 )
                 # Handle empty patch_pixel_values by setting to None
                 if patch_pixel_values.shape[0] == 0:
                     patch_pixel_values = None
-            num_patches = flatten_bn(num_patches, concat=True).tolist()
+            if isinstance(num_patches, torch.Tensor):
+                num_patches = num_patches.tolist()
+            elif isinstance(num_patches, list):
+                num_patches = [
+                    n.item() if isinstance(n, torch.Tensor) else n for n in num_patches
+                ]
 
             return Step3VLImagePixelInputs(
                 type="pixel_values",
```
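The `num_patches` branch above now accepts either a batched tensor or a list that may mix plain ints with 0-dim tensors. A standalone check of that normalization, with hypothetical values (`normalize_num_patches` is an illustrative name, not a vLLM function):

```python
import torch


def normalize_num_patches(num_patches):
    # Mirrors the new step3_vl logic: a tensor becomes a Python list, and a
    # list may mix plain ints with 0-dim tensors that need .item().
    if isinstance(num_patches, torch.Tensor):
        return num_patches.tolist()
    if isinstance(num_patches, list):
        return [n.item() if isinstance(n, torch.Tensor) else n for n in num_patches]
    return num_patches


assert normalize_num_patches(torch.tensor([4, 9, 1])) == [4, 9, 1]
assert normalize_num_patches([torch.tensor(4), 9, torch.tensor(1)]) == [4, 9, 1]
```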

vllm/model_executor/models/tarsier.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -47,12 +47,7 @@
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (
-    AutoWeightsLoader,
-    flatten_bn,
-    init_vllm_registered_model,
-    maybe_prefix,
-)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 from .vision import (
     VisionEncoderInfo,
     get_num_selected_vision_tokens,
@@ -404,6 +399,8 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers_total: int) ->
     dummy_inputs=TarsierDummyInputsBuilder,
 )
 class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -472,9 +469,11 @@ def _parse_and_validate_image_input(
                     f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
                 )
 
+            if isinstance(pixel_values, list):
+                pixel_values = torch.cat(pixel_values)
             return TarsierImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
+                pixel_values=pixel_values,
             )
 
         if image_embeds is not None:
@@ -483,9 +482,11 @@ def _parse_and_validate_image_input(
                     "Incorrect type of image embeddings. "
                     f"Got type: {type(image_embeds)}"
                 )
+            if isinstance(image_embeds, list):
+                image_embeds = torch.cat(image_embeds)
             return TarsierImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
```
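Both Tarsier branches now reduce to the same guard: concatenate only when the merged input is still a per-request list. A minimal sketch, assuming `(N_i, C, H, W)` image tensors (shapes and the `merge_if_list` name are illustrative):

```python
import torch


def merge_if_list(x):
    # The pattern used for both pixel_values and image_embeds in tarsier.py:
    # merged inputs arrive flat; per-request lists still need one concat.
    return torch.cat(x) if isinstance(x, list) else x


per_request = [torch.randn(2, 3, 336, 336), torch.randn(1, 3, 336, 336)]
assert merge_if_list(per_request).shape == (3, 3, 336, 336)
assert merge_if_list(torch.randn(3, 3, 336, 336)).shape == (3, 3, 336, 336)
```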

vllm/model_executor/models/terratorch.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -226,6 +226,7 @@ def apply(
     dummy_inputs=TerratorchInputBuilder,
 )
 class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
+    merge_by_field_config = True
     supports_multimodal_raw_input_only = True
     is_pooling_model = True
```
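For Terratorch the whole change is this one-line opt-in; schematically (the class below is a hypothetical example, not a vLLM model):

```python
import torch.nn as nn


class MyMultiModalModel(nn.Module):  # hypothetical example class
    # Opting in is a plain class attribute: the multimodal processor then
    # merges kwargs by field config before calling the model.
    merge_by_field_config = True
```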

vllm/model_executor/models/transformers.py

Lines changed: 0 additions & 12 deletions
```diff
@@ -79,7 +79,6 @@
     AutoWeightsLoader,
     PPMissingLayer,
     WeightsMapper,
-    flatten_bn,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
@@ -842,17 +841,6 @@ def compute_logits(
         return logits
 
 
-def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
-    """Flatten until a list of tensors can be concatenated then do concat"""
-
-    def _can_concat(x: list[torch.Tensor]):
-        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
-
-    if _can_concat(x):
-        return torch.concat(x)
-    return flatten_and_concat(flatten_bn(x))
-
-
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
```
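The deleted `flatten_and_concat` recursed with `flatten_bn` until its `_can_concat` shape test passed; with merged-by-field inputs that normalization happens upstream, leaving the helper dead code. A quick standalone check of that shape test, re-typed from the removed code:

```python
import torch


def _can_concat(x: list[torch.Tensor]) -> bool:
    # True when all tensors share the same trailing shape, i.e. they can be
    # concatenated along dim 0 (re-typed from the deleted helper).
    return len(set(t.shape[1:] for t in x)) == 1


assert _can_concat([torch.zeros(2, 3), torch.zeros(5, 3)])
assert not _can_concat([torch.zeros(2, 3), torch.zeros(2, 4)])
```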
