@@ -69,18 +69,16 @@ class UltravoxAudioFeatureInputs(TensorSchema):
     type: Literal["audio_features"]
     data: Annotated[
         Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
-        TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"}),
+        TensorShape("bn", "nmb", "t"),
     ]
-    lens: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio frames. Used for attention mask in WhisperEncoder."""
-    token_len: Annotated[
-        Union[torch.Tensor, list[torch.Tensor]],
-        TensorShape("b", "n", dynamic_dims={"n"}),
-    ]
-    """Length of the audio tokens. Used for flattening the audio features."""
+    lens: Annotated[torch.Tensor, TensorShape("bn")]
+    """
+    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
+    """
+    token_len: Annotated[torch.Tensor, TensorShape("bn")]
+    """Length of the audio tokens per chunk. Used for flattening the audio features."""
+    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
+    """Number of chunks per audio. Used for flattening the audio features."""
 
 
 class UltravoxAudioEmbeddingInputs(TensorSchema):
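The hunk above flattens all audio chunks across the batch into a single leading "bn" dimension and adds a per-audio `num_chunks` tensor for bookkeeping. A rough illustration of that layout, using standalone tensors rather than the actual `UltravoxAudioFeatureInputs`/`TensorShape` machinery, with invented concrete sizes:

```python
import torch

# Two audios: the first split into 2 mel chunks, the second into 3.
num_chunks = torch.tensor([2, 3])        # shape ("n",): one entry per audio
total_chunks = int(num_chunks.sum())     # "bn" = 5 chunks across the whole batch

# Hypothetical sizes: 80 mel bins ("nmb"), chunks padded to 3000 frames ("t").
data = torch.randn(total_chunks, 80, 3000)
lens = torch.tensor([3000, 1200, 3000, 3000, 800])  # valid frames per chunk, shape ("bn",)
token_len = torch.tensor([188, 75, 188, 188, 50])   # audio tokens per chunk, shape ("bn",)

assert data.shape[0] == lens.shape[0] == token_len.shape[0] == total_chunks
```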
@@ -421,6 +419,8 @@ def forward(
     dummy_inputs=UltravoxDummyInputsBuilder,
 )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -519,6 +519,7 @@ def _parse_and_validate_audio_input(
         audio_embeds = kwargs.pop("audio_embeds", None)
         audio_lens = kwargs.pop("audio_lens", None)
         audio_token_len = kwargs.pop("audio_token_len", None)
+        audio_num_chunks = kwargs.pop("audio_num_chunks", None)
 
         if audio_features is None and audio_embeds is None:
             return None
@@ -529,6 +530,7 @@ def _parse_and_validate_audio_input(
                 data=audio_features,
                 lens=audio_lens,
                 token_len=audio_token_len,
+                num_chunks=audio_num_chunks,
             )
 
         if audio_embeds is not None:
@@ -547,9 +549,8 @@ def _process_audio_input(
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])
 
-        # [B1, B2] -> [B1+B2]
-        audio_lens = flatten_bn(audio_input["lens"], concat=True)
-        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
+        audio_lens = audio_input["lens"]
+        audio_token_len = audio_input["token_len"]
 
         embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
 
@@ -568,7 +569,8 @@ def _process_audio_input(
 
         # Return one tensor per input audio
         embed_lens = [
-            token_len_item.sum().item() for token_len_item in audio_input["token_len"]
+            chunk_lens.sum().item()
+            for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
         ]
         return flattened_embeddings.split(embed_lens)
 
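With `token_len` and `num_chunks` both flattened, the per-audio embedding lengths above come from splitting the flat per-chunk tensor back into one group per audio. A minimal sketch of that regrouping, with invented values and plain tensors rather than the model's inputs:

```python
import torch

token_len = torch.tensor([188, 75, 188, 188, 50])  # tokens per chunk, flattened ("bn",)
num_chunks = torch.tensor([2, 3])                   # chunks per audio ("n",)

# split() along dim 0 turns the flat per-chunk tensor into one group per audio;
# summing each group gives that audio's total number of embedding tokens.
embed_lens = [
    chunk_lens.sum().item() for chunk_lens in token_len.split(num_chunks.tolist())
]
print(embed_lens)  # [263, 426]
```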
@@ -663,6 +665,7 @@ def pad_and_concat_to_dim3(
         if features.ndim > 3:
             # Flatten [B, N, 80, M] -> [B * N, 80, M]
             features = flatten_bn(features)
+
         return features
 
     features = [pad_and_concat_to_dim3(f) for f in features]