
Commit 81fe190

[Model] Use merge_by_field_config for MM models (U-Z)
Signed-off-by: Ayush Satyam <[email protected]>
Parent: 512b8af
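
Context for the diff below: when a model class sets merge_by_field_config = True, vLLM's multimodal input machinery merges batched keyword arguments along the batch dimension before they reach the model, so the per-model flatten_bn calls removed in this commit are no longer needed. A minimal sketch of the flattening involved (illustrative only; flatten_bn_sketch is a hypothetical stand-in, not vLLM's actual flatten_bn implementation):

import torch

# Illustrative only: the batch flattening these models previously did by
# hand via flatten_bn. Inputs of shape [B, N, ...] (B requests, N items
# per request) are collapsed to [B * N, ...].
def flatten_bn_sketch(x: torch.Tensor) -> torch.Tensor:
    return x.flatten(0, 1)

features = torch.randn(2, 3, 80, 100)  # [B=2, N=3, 80, M=100]
assert flatten_bn_sketch(features).shape == (6, 80, 100)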

3 files changed (+10, -7 lines)

vllm/model_executor/models/ultravox.py

Lines changed: 6 additions & 5 deletions
@@ -48,7 +48,6 @@
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -421,6 +420,8 @@ def forward(
     dummy_inputs=UltravoxDummyInputsBuilder,
 )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -547,9 +548,8 @@ def _process_audio_input(
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])

-        # [B1, B2] -> [B1+B2]
-        audio_lens = flatten_bn(audio_input["lens"], concat=True)
-        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
+        audio_lens = audio_input["lens"]
+        audio_token_len = audio_input["token_len"]

         embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)

@@ -662,7 +662,8 @@ def pad_and_concat_to_dim3(
     if isinstance(features, torch.Tensor):
         if features.ndim > 3:
             # Flatten [B, N, 80, M] -> [B * N, 80, M]
-            features = flatten_bn(features)
+            batch_size = features.shape[0] * features.shape[1]
+            features = features.view(batch_size, *features.shape[2:])
         return features

     features = [pad_and_concat_to_dim3(f) for f in features]
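
The last hunk replaces flatten_bn with an explicit view. A quick self-contained check that the two spellings agree for a 4-D input (a sketch; vLLM's flatten_bn helper itself is not reproduced here):

import torch

features = torch.randn(2, 3, 80, 100)                    # [B, N, 80, M]
batch_size = features.shape[0] * features.shape[1]
viewed = features.view(batch_size, *features.shape[2:])  # [6, 80, 100]
# flatten(0, 1) collapses the same two leading dims, so both agree.
assert torch.equal(viewed, features.flatten(0, 1))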

vllm/model_executor/models/voxtral.py

Lines changed: 3 additions & 2 deletions
@@ -61,7 +61,7 @@
 )

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
-from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
+from .utils import init_vllm_registered_model, maybe_prefix

 logger = init_logger(__name__)

@@ -337,6 +337,8 @@ def _get_data_parser(self) -> MultiModalDataParser:
 class VoxtralForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
 ):
+    merge_by_field_config = True
+
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
@@ -445,7 +447,6 @@ def _parse_and_validate_audio_arrays(
             f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
         )

-        audio_arrays = flatten_bn(audio_arrays)
         if isinstance(audio_arrays, torch.Tensor):
             audio_arrays = list(audio_arrays.unbind(0))
         return audio_arrays
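
For reference, the unbind(0) pattern retained in the last hunk splits a batched tensor into a list of per-sample tensors (views, not copies). A minimal standalone example (the shapes are made up for illustration):

import torch

audio_arrays = torch.randn(4, 16000)    # e.g. 4 one-second 16 kHz clips
as_list = list(audio_arrays.unbind(0))  # 4 tensors of shape [16000]
assert len(as_list) == 4
assert as_list[0].shape == (16000,)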

vllm/model_executor/models/whisper.py

Lines changed: 1 addition & 0 deletions
@@ -781,6 +781,7 @@ def _get_prompt_updates(
 class WhisperForConditionalGeneration(
     nn.Module, SupportsTranscription, SupportsMultiModal
 ):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
