Commit f115476

refactor(multimodal R-T): Migrate MM models to merge_by_field_config
Migrate step3_vl, tarsier, terratorch, and transformers to use merge_by_field_config = True, enabling HF-compatible input shapes. Remove flatten_bn calls and the dead flatten_and_concat function.

Signed-off-by: Ayush Satyam <[email protected]>
1 parent 432e1cb commit f115476
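
Roughly, the migration moves batch flattening out of the models: with merge_by_field_config = True, vLLM hands each model its multimodal tensors already merged across requests, in the shapes the HF processors produce. A minimal sketch of the idea (the helper name and shapes here are illustrative, not vLLM's implementation):

import torch

# Stand-in for the removed per-model flatten_bn(x, concat=True) calls:
# merge per-request item batches into one batch of items along dim 0.
def flatten_bn_concat(x: list[torch.Tensor]) -> torch.Tensor:
    # (num_items_i, *feature_dims) each -> (sum_i num_items_i, *feature_dims)
    return torch.cat(list(x), dim=0)

# Two requests carrying 2 and 1 images respectively (illustrative shapes).
per_request = [torch.randn(2, 3, 336, 336), torch.randn(1, 3, 336, 336)]
merged = flatten_bn_concat(per_request)
print(merged.shape)  # torch.Size([3, 3, 336, 336])

# With merge_by_field_config = True the model receives `merged` directly,
# so each model no longer needs its own flatten_bn call.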

File tree: 4 files changed, +16 −22 lines


vllm/model_executor/models/step3_vl.py

Lines changed: 9 additions & 5 deletions
@@ -37,7 +37,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 from .vision import run_dp_sharded_vision_model
 
@@ -836,6 +836,7 @@ def forward(
                                         dummy_inputs=Step3VLDummyInputsBuilder)
 class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "model.": "language_model.model.",
@@ -917,18 +918,21 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
             if pixel_values.dim() >= 3:
                 pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
             if patch_pixel_values is not None:
-                patch_pixel_values = flatten_bn(patch_pixel_values,
-                                                concat=True)
                 patch_pixel_values = patch_pixel_values.view(
                     -1, *patch_pixel_values.shape[-3:])
                 # Handle empty patch_pixel_values by setting to None
                 if patch_pixel_values.shape[0] == 0:
                     patch_pixel_values = None
-            num_patches = flatten_bn(num_patches, concat=True).tolist()
+            if isinstance(num_patches, torch.Tensor):
+                num_patches = num_patches.tolist()
+            elif isinstance(num_patches, list):
+                num_patches = [
+                    n.item() if isinstance(n, torch.Tensor) else n
+                    for n in num_patches
+                ]
 
             return Step3VLImagePixelInputs(
                 type="pixel_values",
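
The num_patches handling above no longer assumes a nested batch; it only normalizes whatever arrives into a plain list of ints. A standalone sketch of that logic with hypothetical inputs (the function name is ours, the body mirrors the diff):

import torch

# Mirror of the new num_patches normalization: accept a batched tensor or a
# list mixing ints and 0-d tensors, and return a plain list[int].
def normalize_num_patches(num_patches):
    if isinstance(num_patches, torch.Tensor):
        return num_patches.tolist()
    if isinstance(num_patches, list):
        return [n.item() if isinstance(n, torch.Tensor) else n
                for n in num_patches]
    return num_patches

print(normalize_num_patches(torch.tensor([4, 9])))  # [4, 9]
print(normalize_num_patches([torch.tensor(4), 9]))  # [4, 9]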

vllm/model_executor/models/tarsier.py

Lines changed: 5 additions & 4 deletions
@@ -38,8 +38,7 @@
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 from .vision import (VisionEncoderInfo, get_num_selected_vision_tokens,
                      get_vision_encoder_info)
 
@@ -386,6 +385,8 @@ def _get_layer_index(feature_layer_index: int,
                                         dummy_inputs=TarsierDummyInputsBuilder)
 class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"]
@@ -450,7 +451,7 @@ def _parse_and_validate_image_input(
 
             return TarsierImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
+                pixel_values=pixel_values,
             )
 
         if image_embeds is not None:
@@ -459,7 +460,7 @@
                 f"Got type: {type(image_embeds)}")
             return TarsierImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
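
After the change, tarsier's parser forwards the already-merged tensors straight into its typed inputs. An illustrative pattern of that pass-through (the TypedDict body and helper name are assumed, not copied from the file):

from typing import Literal, TypedDict

import torch

class TarsierImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    # Assumed layout: (total_num_images, num_channels, height, width)
    pixel_values: torch.Tensor

def parse_image_input(pixel_values: torch.Tensor) -> TarsierImagePixelInputs:
    # No flatten_bn: inputs already arrive merged across requests.
    return TarsierImagePixelInputs(type="pixel_values",
                                   pixel_values=pixel_values)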

vllm/model_executor/models/terratorch.py

Lines changed: 1 addition & 0 deletions
@@ -215,6 +215,7 @@ def apply(
     dummy_inputs=TerratorchInputBuilder,
 )
 class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
+    merge_by_field_config = True
     supports_multimodal_raw_input_only = True
     is_pooling_model = True
 

vllm/model_executor/models/transformers.py

Lines changed: 1 addition & 13 deletions
@@ -59,8 +59,7 @@
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP, SupportsQuant)
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    flatten_bn, make_empty_intermediate_tensors_factory,
-                    maybe_prefix)
+                    make_empty_intermediate_tensors_factory, maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -812,17 +811,6 @@ def compute_logits(
         return logits
 
 
-def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
-    """Flatten until a list of tensors can be concatenated then do concat"""
-
-    def _can_concat(x: list[torch.Tensor]):
-        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
-
-    if _can_concat(x):
-        return torch.concat(x)
-    return flatten_and_concat(flatten_bn(x))
-
-
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
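
The deleted flatten_and_concat helper recursed until every tensor in the (possibly nested) input shared its trailing dimensions, then concatenated; with inputs now arriving pre-merged it had no remaining callers. A self-contained sketch, assuming a flatten_bn that strips one leading batch dimension per element (as in vllm/model_executor/models/utils.py):

import torch

def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
    # Assumed behavior: split off each element's leading dim
    # (iterating a tensor yields its dim-0 slices).
    return [x_n for x_b in x for x_n in x_b]

def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
    # Concatenate once all trailing shapes agree; otherwise flatten and retry.
    if len({t.shape[1:] for t in x}) == 1:
        return torch.concat(x)
    return flatten_and_concat(flatten_bn(x))

# (2, 3, 4) and (1, 3, 4) already agree on trailing dims -> (3, 3, 4).
print(flatten_and_concat([torch.zeros(2, 3, 4), torch.ones(1, 3, 4)]).shape)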
