@@ -48,7 +48,6 @@
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
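For reviewers unfamiliar with the removed helper: `flatten_bn` collapsed the leading batch dimensions of multimodal inputs. A minimal sketch of its apparent semantics, inferred from how the call sites in the hunks below use it (not vLLM's exact implementation):

```python
from typing import Union

import torch


def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]], *, concat: bool = False
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Rough sketch of the removed helper's apparent semantics."""
    if isinstance(x, torch.Tensor):
        # Tensor case: [B, N, ...] -> [B * N, ...]
        return x.flatten(0, 1)
    if concat:
        # List case with concat=True: [[B1, ...], [B2, ...]] -> [B1 + B2, ...]
        return torch.cat(x)
    return x
```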
@@ -421,6 +420,8 @@ def forward(
     dummy_inputs=UltravoxDummyInputsBuilder,
 )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
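Judging from this diff, setting `merge_by_field_config = True` asks vLLM's multimodal input machinery to hand the model pre-merged tensors per field, which is what makes the explicit `flatten_bn` calls below redundant. An illustrative sketch of the resulting shape contract (dummy values; the field name is taken from the next hunk):

```python
import torch

# Without merging, a per-request field arrives as a list of tensors:
unmerged_lens = [torch.tensor([120, 95]), torch.tensor([88])]  # B1=2, B2=1

# With merge_by_field_config = True, the model receives a single
# pre-merged tensor, so no flatten_bn(..., concat=True) is needed:
merged_lens = torch.cat(unmerged_lens)  # shape [B1 + B2] == [3]
assert merged_lens.shape == (3,)
```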
@@ -547,9 +548,8 @@ def _process_audio_input(
         # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
         audio_features = pad_and_concat_to_dim3(audio_input["data"])
 
-        # [B1, B2] -> [B1+B2]
-        audio_lens = flatten_bn(audio_input["lens"], concat=True)
-        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)
+        audio_lens = audio_input["lens"]
+        audio_token_len = audio_input["token_len"]
 
         embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)
 
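Note that the `data` field can still arrive as a list of per-request mel tensors with different lengths `M`, which is why `pad_and_concat_to_dim3` is kept. A toy illustration of the shape comment above that call (dummy shapes, 80 mel bins):

```python
import torch
import torch.nn.functional as F

# Two requests whose mel-feature tensors differ in length M get padded
# to a common M and concatenated along the batch dimension.
chunks = [torch.randn(2, 80, 100), torch.randn(1, 80, 160)]  # M1=100, M2=160
max_m = max(c.shape[-1] for c in chunks)
padded = [F.pad(c, (0, max_m - c.shape[-1])) for c in chunks]
audio_features = torch.cat(padded)  # [B1 + B2, 80, max(M1, M2)]
assert audio_features.shape == (3, 80, 160)
```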
@@ -662,7 +662,8 @@ def pad_and_concat_to_dim3(
     if isinstance(features, torch.Tensor):
         if features.ndim > 3:
             # Flatten [B, N, 80, M] -> [B * N, 80, M]
-            features = flatten_bn(features)
+            batch_size = features.shape[0] * features.shape[1]
+            features = features.view(batch_size, *features.shape[2:])
         return features
 
     features = [pad_and_concat_to_dim3(f) for f in features]
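The new two-line flatten is behaviorally identical to collapsing the first two dimensions, which is what `flatten_bn` did for the tensor case. A quick self-contained check (dummy shapes):

```python
import torch

features = torch.randn(2, 3, 80, 50)  # [B, N, 80, M]
batch_size = features.shape[0] * features.shape[1]
flat_view = features.view(batch_size, *features.shape[2:])
assert flat_view.shape == (6, 80, 50)
# Matches collapsing the first two dims directly:
assert torch.equal(flat_view, features.flatten(0, 1))
```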