Skip to content

Commit 868a8c5

Browse files
[Bugfix] Fix Ultravox on V1 (#14929)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent b4ad56c commit 868a8c5

File tree

1 file changed: 25 additions and 17 deletions.

vllm/model_executor/models/ultravox.py

Lines changed: 25 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import math
66
from collections.abc import Iterable, Mapping, Sequence
77
from functools import cached_property
8-
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
8+
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
99

1010
import torch
1111
import torch.utils.checkpoint
@@ -36,7 +36,7 @@
3636
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
3737

3838
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
39-
SupportsMultiModal, SupportsPP, SupportsV0Only)
39+
SupportsMultiModal, SupportsPP)
4040
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
4141
init_vllm_registered_model, maybe_prefix,
4242
merge_multimodal_embeddings,
@@ -50,14 +50,14 @@
5050

5151
class UltravoxAudioFeatureInputs(TypedDict):
5252
type: Literal["audio_features"]
53-
data: NestedTensors
53+
data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
5454
"""Shape: `(batch_size, num_chunks, 80, M)`"""
55-
lens: NestedTensors
55+
lens: Union[torch.Tensor, list[torch.Tensor]]
5656
"""
5757
Length of the audio frames. Used for attention mask in WhisperEncoder.
5858
Shape: `(batch_size, num_chunks)`
5959
"""
60-
token_len: NestedTensors
60+
token_len: Union[torch.Tensor, list[torch.Tensor]]
6161
"""
6262
Length of the audio tokens. Used for flattening the audio features.
6363
Shape: `(batch_size, num_chunks)`
@@ -405,8 +405,7 @@ def forward(
405405
UltravoxMultiModalProcessor,
406406
info=UltravoxProcessingInfo,
407407
dummy_inputs=UltravoxDummyInputsBuilder)
408-
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
409-
SupportsV0Only):
408+
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
410409

411410
packed_modules_mapping = {
412411
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -506,6 +505,12 @@ def _parse_and_validate_audio_input(
506505
if not isinstance(audio_features, (torch.Tensor, list)):
507506
raise ValueError("Incorrect type of audio features. "
508507
f"Got type: {type(audio_features)}")
508+
if not isinstance(audio_lens, (torch.Tensor, list)):
509+
raise ValueError("Incorrect type of audio_lens. "
510+
f"Got type: {type(audio_features)}")
511+
if not isinstance(audio_token_len, (torch.Tensor, list)):
512+
raise ValueError("Incorrect type of audio_token_len. "
513+
f"Got type: {type(audio_features)}")
509514

510515
return UltravoxAudioFeatureInputs(type="audio_features",
511516
data=audio_features,
@@ -523,21 +528,19 @@ def _parse_and_validate_audio_input(
523528
raise AssertionError("This line should be unreachable.")
524529

525530
def _process_audio_input(
526-
self, audio_input: UltravoxAudioInputs) -> NestedTensors:
531+
self,
532+
audio_input: UltravoxAudioInputs,
533+
) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
527534
if audio_input["type"] == "audio_embeds":
528535
return audio_input["data"]
529536

530537
# Pad and concatenate audio features
531538
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
532539
audio_features = pad_and_concat_to_dim3(audio_input["data"])
533540

534-
if isinstance(audio_input['lens'], list):
535-
# [B1, B2] -> [B1+B2]
536-
audio_lens = torch.cat(audio_input['lens'])
537-
audio_token_len = torch.cat(audio_input['token_len'])
538-
else:
539-
audio_lens = flatten_bn(audio_input['lens'])
540-
audio_token_len = flatten_bn(audio_input['token_len'])
541+
# [B1, B2] -> [B1+B2]
542+
audio_lens = flatten_bn(audio_input['lens'], concat=True)
543+
audio_token_len = flatten_bn(audio_input['token_len'], concat=True)
541544

542545
embeddings = self._audio_features_to_embeddings(
543546
audio_features, audio_lens)
@@ -554,7 +557,12 @@ def _process_audio_input(
554557
# Apply mask and flatten
555558
flattened_embeddings = embeddings[mask]
556559

557-
return flattened_embeddings
560+
# Return one tensor per input audio
561+
embed_lens = [
562+
token_len_item.sum().item()
563+
for token_len_item in audio_input['token_len']
564+
]
565+
return flattened_embeddings.split(embed_lens)
558566

559567
def get_multimodal_embeddings(
560568
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -646,7 +654,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
646654

647655

648656
def pad_and_concat_to_dim3(
649-
features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]
657+
features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
650658
) -> torch.Tensor:
651659
"""
652660
Pad and concatenate a list of tensors.

0 commit comments

Comments
 (0)