Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions vllm/model_executor/models/ultravox.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def forward(
dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
merge_by_field_config = True

packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
Expand Down Expand Up @@ -663,6 +665,7 @@ def pad_and_concat_to_dim3(
if features.ndim > 3:
# Flatten [B, N, 80, M] -> [B * N, 80, M]
features = flatten_bn(features)

return features

features = [pad_and_concat_to_dim3(f) for f in features]
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/models/voxtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
)

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .utils import init_vllm_registered_model, maybe_prefix

logger = init_logger(__name__)

Expand Down Expand Up @@ -337,6 +337,8 @@ def _get_data_parser(self) -> MultiModalDataParser:
class VoxtralForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
merge_by_field_config = True

supported_languages = ISO639_1_SUPPORTED_LANGS

packed_modules_mapping = {
Expand Down Expand Up @@ -445,7 +447,6 @@ def _parse_and_validate_audio_arrays(
f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
)

audio_arrays = flatten_bn(audio_arrays)
if isinstance(audio_arrays, torch.Tensor):
audio_arrays = list(audio_arrays.unbind(0))
return audio_arrays
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def _get_prompt_updates(
class WhisperForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
):
merge_by_field_config = True
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
Expand Down