Commit a7615f9

refactor(multimodal U-Z): Migrate MM models to merge_by_field_config
Migrate step3_vl, tarsier, terratorch, ultravox, voxtral, and whisper to use merge_by_field_config = True, enabling HF-compatible input shapes. Remove flatten_bn calls and the dead flatten_and_concat function.

Signed-off-by: Ayush Satyam <[email protected]>
1 parent 512b8af commit a7615f9
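
For context, here is a rough sketch of the shape handling that the removed flatten_bn calls performed. This is illustrative only; the helper name and shapes below are made up and are not the repository's code. With merge_by_field_config = True, the multimodal processor is expected to hand the model tensors already merged over the batch and per-prompt item dimensions, so this flattening no longer needs to happen inside each model:

import torch

def flatten_bn_sketch(x: torch.Tensor) -> torch.Tensor:
    # What the removed per-model flattening did: [B, N, ...] -> [B * N, ...]
    return x.flatten(0, 1)

# Hypothetical audio-feature batch: 2 prompts with 3 audio clips each
legacy = torch.randn(2, 3, 80, 100)   # [B, N, 80, M]
merged = flatten_bn_sketch(legacy)    # [6, 80, 100], the HF-compatible shape
assert merged.shape == (6, 80, 100)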

File tree

vllm/model_executor/models/ultravox.py
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/whisper.py

3 files changed: +11, -7 lines

vllm/model_executor/models/ultravox.py

Lines changed: 7 additions & 5 deletions

@@ -48,7 +48,6 @@
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -421,6 +420,8 @@ def forward(
     dummy_inputs=UltravoxDummyInputsBuilder,
 )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -547,9 +548,9 @@ def _process_audio_input(
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])

-        # [B1, B2] -> [B1+B2]
-        audio_lens = flatten_bn(audio_input["lens"], concat=True)
-        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
+        # Audio lens and token_len are already in the correct shape
+        audio_lens = audio_input["lens"]
+        audio_token_len = audio_input["token_len"]

         embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)

@@ -662,7 +663,8 @@ def pad_and_concat_to_dim3(
     if isinstance(features, torch.Tensor):
         if features.ndim > 3:
             # Flatten [B, N, 80, M] -> [B * N, 80, M]
-            features = flatten_bn(features)
+            batch_size = features.shape[0] * features.shape[1]
+            features = features.view(batch_size, *features.shape[2:])
         return features

     features = [pad_and_concat_to_dim3(f) for f in features]
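
The last hunk swaps the helper for an explicit reshape inside pad_and_concat_to_dim3. A minimal standalone check of that equivalence, assuming a contiguous 4-D tensor (not code from the repository):

import torch

features = torch.randn(2, 3, 80, 100)                        # [B, N, 80, M]
batch_size = features.shape[0] * features.shape[1]
flattened = features.view(batch_size, *features.shape[2:])   # [B * N, 80, M]
assert flattened.shape == (6, 80, 100)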

vllm/model_executor/models/voxtral.py

Lines changed: 3 additions & 2 deletions

@@ -61,7 +61,7 @@
 )

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
-from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
+from .utils import init_vllm_registered_model, maybe_prefix

 logger = init_logger(__name__)

@@ -337,6 +337,8 @@ def _get_data_parser(self) -> MultiModalDataParser:
 class VoxtralForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
 ):
+    merge_by_field_config = True
+
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
@@ -445,7 +447,6 @@ def _parse_and_validate_audio_arrays(
                 f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
             )

-        audio_arrays = flatten_bn(audio_arrays)
         if isinstance(audio_arrays, torch.Tensor):
             audio_arrays = list(audio_arrays.unbind(0))
         return audio_arrays
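
With the flatten_bn call gone, voxtral assumes audio_arrays already arrives merged; the remaining unbind(0) just splits a batched tensor into a list of per-audio tensors. A small illustration with hypothetical shapes (not repository code):

import torch

audio_arrays = torch.randn(4, 16000)     # already [num_audios, num_samples]
as_list = list(audio_arrays.unbind(0))   # four 1-D tensors of 16000 samples each
assert len(as_list) == 4 and as_list[0].shape == (16000,)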

vllm/model_executor/models/whisper.py

Lines changed: 1 addition & 0 deletions

@@ -781,6 +781,7 @@ def _get_prompt_updates(
 class WhisperForConditionalGeneration(
     nn.Module, SupportsTranscription, SupportsMultiModal
 ):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
