
Commit 5f7e8a9

[Model] Define merge_by_field_config MM interface (U-Z) (#26261)
Signed-off-by: Ayush Satyam <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 4dbdf4a commit 5f7e8a9

File tree: 4 files changed, 32 insertions(+), 25 deletions(-)
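This installment of the merge_by_field_config migration covers the U-Z multimodal models (Ultravox, Voxtral, Whisper). The recurring pattern in the diffs below is a one-line opt-in on each model class, after which the manual per-prompt flattening helpers are no longer needed, e.g.:

class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    merge_by_field_config = True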

vllm/model_executor/models/ultravox.py

Lines changed: 18 additions & 15 deletions
@@ -69,18 +69,16 @@ class UltravoxAudioFeatureInputs(TensorSchema):
     type: Literal["audio_features"]
     data: Annotated[
         Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
-        TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"}),
+        TensorShape("bn", "nmb", "t"),
     ]
-    lens: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio frames. Used for attention mask in WhisperEncoder."""
-    token_len: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio tokens. Used for flattening the audio features."""
+    lens: Annotated[torch.Tensor, TensorShape("bn")]
+    """
+    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
+    """
+    token_len: Annotated[torch.Tensor, TensorShape("bn")]
+    """Length of the audio tokens per chunk. Used for flattening the audio features."""
+    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
+    """Number of chunks per audio. Used for flattening the audio features."""
 
 
 class UltravoxAudioEmbeddingInputs(TensorSchema):
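With this flattened schema, data, lens and token_len carry one entry per audio chunk ("bn" = batch and chunks merged into a single dimension), and num_chunks records how many chunks each input audio contributed. A rough sketch of a conforming set of tensors, with invented sizes (80 mel bins, matching the comments later in this file):

import torch

num_chunks = torch.tensor([2, 1])          # "n": two audios -> 2 chunks and 1 chunk
data = torch.randn(3, 80, 3000)            # "bn" x "nmb" x "t": 3 chunks in total
lens = torch.tensor([3000, 1500, 2800])    # audio frames per chunk, shape ("bn",)
token_len = torch.tensor([188, 94, 175])   # audio tokens per chunk, shape ("bn",)

assert data.shape[0] == lens.shape[0] == token_len.shape[0] == int(num_chunks.sum())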
@@ -421,6 +419,8 @@ def forward(
     dummy_inputs=UltravoxDummyInputsBuilder,
 )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -519,6 +519,7 @@ def _parse_and_validate_audio_input(
         audio_embeds = kwargs.pop("audio_embeds", None)
         audio_lens = kwargs.pop("audio_lens", None)
         audio_token_len = kwargs.pop("audio_token_len", None)
+        audio_num_chunks = kwargs.pop("audio_num_chunks", None)
 
         if audio_features is None and audio_embeds is None:
             return None
@@ -529,6 +530,7 @@ def _parse_and_validate_audio_input(
                 data=audio_features,
                 lens=audio_lens,
                 token_len=audio_token_len,
+                num_chunks=audio_num_chunks,
             )
 
         if audio_embeds is not None:
@@ -547,9 +549,8 @@ def _process_audio_input(
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])
 
-        # [B1, B2] -> [B1+B2]
-        audio_lens = flatten_bn(audio_input["lens"], concat=True)
-        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
+        audio_lens = audio_input["lens"]
+        audio_token_len = audio_input["token_len"]
 
         embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
 
@@ -568,7 +569,8 @@ def _process_audio_input(
 
         # Return one tensor per input audio
         embed_lens = [
-            token_len_item.sum().item() for token_len_item in audio_input["token_len"]
+            chunk_lens.sum().item()
+            for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
         ]
         return flattened_embeddings.split(embed_lens)
 
@@ -663,6 +665,7 @@ def pad_and_concat_to_dim3(
         if features.ndim > 3:
             # Flatten [B, N, 80, M] -> [B * N, 80, M]
             features = flatten_bn(features)
+
         return features
 
     features = [pad_and_concat_to_dim3(f) for f in features]
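The regrouping in _process_audio_input is the core of this change: token_len is now a flat per-chunk tensor, so num_chunks is what maps chunks back to their source audios before the flattened embeddings are split. A self-contained sketch of that regrouping, with invented sizes:

import torch

token_len = torch.tensor([188, 94, 125, 60])   # audio tokens per chunk ("bn",)
num_chunks = torch.tensor([3, 1])              # chunks per input audio ("n",)
flattened_embeddings = torch.randn(int(token_len.sum()), 4096)  # hidden size is made up

# Sum the per-chunk token counts within each audio, then split the flattened
# embeddings so that each input audio gets back exactly one tensor.
embed_lens = [
    chunk_lens.sum().item()
    for chunk_lens in token_len.split(num_chunks.tolist())
]
per_audio = flattened_embeddings.split(embed_lens)
assert [t.shape[0] for t in per_audio] == [188 + 94 + 125, 60]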

vllm/model_executor/models/voxtral.py

Lines changed: 3 additions & 2 deletions
@@ -61,7 +61,7 @@
 )
 
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
-from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
+from .utils import init_vllm_registered_model, maybe_prefix
 
 logger = init_logger(__name__)
 
@@ -337,6 +337,8 @@ def _get_data_parser(self) -> MultiModalDataParser:
 class VoxtralForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
 ):
+    merge_by_field_config = True
+
     supported_languages = ISO639_1_SUPPORTED_LANGS
 
     packed_modules_mapping = {
@@ -445,7 +447,6 @@ def _parse_and_validate_audio_arrays(
                 f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
             )
 
-        audio_arrays = flatten_bn(audio_arrays)
         if isinstance(audio_arrays, torch.Tensor):
             audio_arrays = list(audio_arrays.unbind(0))
         return audio_arrays
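The intent of merge_by_field_config = True is that the multimodal kwargs arrive at the model already grouped according to their field configs, with the per-prompt nesting collapsed, which is why the explicit flatten_bn call above (and its import) is dropped. A minimal sketch of the nesting that flatten_bn used to remove; the values are illustrative, not taken from vLLM:

import torch

per_prompt = [
    [torch.randn(16000)],                     # prompt 1: one audio array
    [torch.randn(8000), torch.randn(24000)],  # prompt 2: two audio arrays
]

# Roughly what the old flatten_bn(audio_arrays) produced: one flat list over the batch.
flat = [audio for prompt_audios in per_prompt for audio in prompt_audios]
assert len(flat) == 3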

vllm/model_executor/models/whisper.py

Lines changed: 8 additions & 8 deletions
@@ -36,7 +36,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
@@ -51,6 +51,7 @@
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.transformers_utils.processor import cached_get_processor
+from vllm.utils.jsontree import json_map_leaves
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
@@ -135,7 +136,10 @@ class WhisperAudioInputs(TensorSchema):
     - t: Time frames (M)
     """
 
-    input_features: Annotated[Optional[NestedTensors], TensorShape("b", "nmb", "t")]
+    input_features: Annotated[
+        Optional[list[torch.Tensor]],
+        TensorShape("b", "nmb", "t"),
+    ]
 
 
 class WhisperEncoderAttention(MultiHeadAttention):
@@ -781,6 +785,7 @@ def _get_prompt_updates(
 class WhisperForConditionalGeneration(
     nn.Module, SupportsTranscription, SupportsMultiModal
 ):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
@@ -936,12 +941,7 @@ def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInput
         input_features = kwargs.pop("input_features", None)
 
         if input_features is not None:
-            if not isinstance(input_features, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of audio features. "
-                    f"Got type: {type(input_features)}"
-                )
-            input_features = torch.cat([feat.to(self.dtype) for feat in input_features])
+            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)
 
         return WhisperAudioInputs(input_features=input_features)
 
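json_map_leaves (imported in the second hunk) applies a callable to every tensor leaf of a nested list/dict structure, so the dtype conversion now happens in one pass and the manual type check plus torch.cat are dropped. A rough stand-alone sketch of the pattern, using a hypothetical local map_leaves helper rather than the vLLM implementation:

import torch

def map_leaves(fn, value):
    # Recurse through dicts/lists/tuples and apply fn to the tensor leaves.
    if isinstance(value, dict):
        return {k: map_leaves(fn, v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(map_leaves(fn, v) for v in value)
    return fn(value)

features = [torch.randn(80, 3000), torch.randn(80, 1500)]  # made-up mel chunks
features = map_leaves(lambda x: x.to(torch.float16), features)
assert all(f.dtype == torch.float16 for f in features)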

vllm/multimodal/inputs.py

Lines changed: 3 additions & 0 deletions
@@ -677,6 +677,9 @@ def __init__(self, field: BaseMultiModalField, modality: str) -> None:
         self.field = field
         self.modality = modality
 
+    def __repr__(self) -> str:
+        return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"
+
     def build_elems(
         self,
         key: str,
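The new __repr__ mostly helps when debugging processor field configs, which previously printed as bare object addresses. A small usage sketch, assuming the existing MultiModalFieldConfig.batched factory (the exact field repr may differ):

from vllm.multimodal.inputs import MultiModalFieldConfig

cfg = MultiModalFieldConfig.batched("audio")
print(cfg)
# Expected to print something along the lines of:
# MultiModalFieldConfig(field=MultiModalBatchedField(), modality=audio)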
