@@ -5,7 +5,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.utils.checkpoint
@@ -36,7 +36,7 @@
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP, SupportsV0Only)
+                         SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings,
@@ -50,14 +50,14 @@
 
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: NestedTensors
+    data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
     """Shape: `(batch_size, num_chunks, 80, M)`"""
-    lens: NestedTensors
+    lens: Union[torch.Tensor, list[torch.Tensor]]
     """
     Length of the audio frames. Used for attention mask in WhisperEncoder.
     Shape: `(batch_size, num_chunks)`
     """
-    token_len: NestedTensors
+    token_len: Union[torch.Tensor, list[torch.Tensor]]
     """
     Length of the audio tokens. Used for flattening the audio features.
     Shape: `(batch_size, num_chunks)`
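Note: the widened `data` annotation admits either a single dense batch or a ragged per-audio list. A minimal sketch of the two layouts (shapes illustrative only, not taken from this change):

```python
import torch

# Dense batch: (batch_size, num_chunks, 80, M)
dense: torch.Tensor = torch.rand(2, 3, 80, 100)

# Ragged batch: one (num_chunks_i, 80, M_i) tensor per audio, where both
# the chunk count and the mel length may differ between inputs.
ragged: list[torch.Tensor] = [torch.rand(3, 80, 100), torch.rand(1, 80, 60)]
```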
@@ -405,8 +405,7 @@ def forward(
     UltravoxMultiModalProcessor,
     info=UltravoxProcessingInfo,
     dummy_inputs=UltravoxDummyInputsBuilder)
-class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
-                    SupportsV0Only):
+class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -506,6 +505,12 @@ def _parse_and_validate_audio_input(
         if not isinstance(audio_features, (torch.Tensor, list)):
             raise ValueError("Incorrect type of audio features. "
                              f"Got type: {type(audio_features)}")
+        if not isinstance(audio_lens, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_lens. "
+                             f"Got type: {type(audio_lens)}")
+        if not isinstance(audio_token_len, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_token_len. "
+                             f"Got type: {type(audio_token_len)}")
 
         return UltravoxAudioFeatureInputs(type="audio_features",
                                           data=audio_features,
@@ -523,21 +528,19 @@ def _parse_and_validate_audio_input(
         raise AssertionError("This line should be unreachable.")
 
     def _process_audio_input(
-            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
+        self,
+        audio_input: UltravoxAudioInputs,
+    ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]
 
         # Pad and concatenate audio features
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])
 
-        if isinstance(audio_input['lens'], list):
-            # [B1, B2] -> [B1+B2]
-            audio_lens = torch.cat(audio_input['lens'])
-            audio_token_len = torch.cat(audio_input['token_len'])
-        else:
-            audio_lens = flatten_bn(audio_input['lens'])
-            audio_token_len = flatten_bn(audio_input['token_len'])
+        # [B1, B2] -> [B1+B2]
+        audio_lens = flatten_bn(audio_input['lens'], concat=True)
+        audio_token_len = flatten_bn(audio_input['token_len'], concat=True)
 
         embeddings = self._audio_features_to_embeddings(
             audio_features, audio_lens)
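The collapsed branch works because `flatten_bn(..., concat=True)` covers both layouts the old code special-cased. A simplified stand-in for the helper, assuming (not verbatim) the semantics of `flatten_bn` from `.utils`:

```python
import torch
from typing import Union

def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]],
    *,
    concat: bool = False,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    if isinstance(x, torch.Tensor):
        # Batched tensor: merge batch and chunk dims, (B, N, ...) -> (B*N, ...)
        return x.flatten(0, 1)
    if concat:
        # List of tensors: [B1, ...], [B2, ...] -> [B1+B2, ...] along dim 0
        return torch.cat(x)
    return [t for sample in x for t in sample]

# Both input layouts now take the same code path:
assert flatten_bn_sketch(torch.ones(2, 3)).shape == (6,)
assert flatten_bn_sketch([torch.ones(2), torch.ones(3)], concat=True).shape == (5,)
```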
@@ -554,7 +557,12 @@ def _process_audio_input(
         # Apply mask and flatten
         flattened_embeddings = embeddings[mask]
 
-        return flattened_embeddings
+        # Return one tensor per input audio
+        embed_lens = [
+            token_len_item.sum().item()
+            for token_len_item in audio_input['token_len']
+        ]
+        return flattened_embeddings.split(embed_lens)
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
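`torch.Tensor.split` accepts a list of section lengths and cuts along dim 0, so the flattened embeddings come back as one tensor per input audio. A quick illustration (sizes hypothetical):

```python
import torch

flattened = torch.rand(10, 4096)  # 10 audio tokens total, hidden size 4096
embed_lens = [6, 4]               # tokens contributed by each input audio
per_audio = flattened.split(embed_lens)
assert [t.shape for t in per_audio] == [(6, 4096), (4, 4096)]
```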
@@ -646,7 +654,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
 
 
 def pad_and_concat_to_dim3(
-    features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]
+    features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
 ) -> torch.Tensor:
     """
     Pad and concatenate a list of tensors.
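For reference, a minimal sketch of the pad-and-concat step described by the earlier comment, assuming right-padding on the last (time) dim; the real `pad_and_concat_to_dim3` also accepts nested lists per its signature:

```python
import torch
import torch.nn.functional as F

def pad_and_concat_sketch(features: list[torch.Tensor]) -> torch.Tensor:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    max_len = max(f.shape[-1] for f in features)
    padded = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(padded)

out = pad_and_concat_sketch([torch.rand(2, 80, 10), torch.rand(3, 80, 7)])
assert out.shape == (5, 80, 10)
```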