
Commit 9d4183d

[model] support qwen2audio embedding input (#23625)
Signed-off-by: Yuekai Zhang <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 513298f commit 9d4183d
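
This commit lets Qwen2-Audio consume precomputed audio embeddings in place of raw waveforms. A minimal usage sketch, hedged: the prompt string is abbreviated, the embedding file and shapes are hypothetical, and only the {"audio": {"audio_embeds": ...}} dict form is what the new parser recognizes.

    import torch
    from vllm import LLM

    llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct")

    # Precomputed audio embeddings: a 2D tensor of shape
    # (num_audio_features, hidden_size), hidden_size matching the LM backbone.
    audio_embeds = torch.load("audio_embeds.pt")  # hypothetical file

    outputs = llm.generate({
        # Abbreviated prompt; the real one comes from the chat template.
        "prompt": "<|audio_bos|><|AUDIO|><|audio_eos|>What is said here?",
        "multi_modal_data": {"audio": {"audio_embeds": audio_embeds}},
    })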

File tree (2 files changed, +93 −29 lines):
  vllm/model_executor/models/qwen2_5_omni_thinker.py
  vllm/model_executor/models/qwen2_audio.py

vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 7 additions & 6 deletions
@@ -47,7 +47,7 @@
     Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
     Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
 from vllm.model_executor.models.qwen2_audio import (
-    Qwen2AudioInputs, Qwen2AudioProcessingInfo,
+    Qwen2AudioFeatureInputs, Qwen2AudioProcessingInfo,
     _get_feat_extract_output_lengths)
 from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -534,7 +534,7 @@ def _validate_and_reshape_mm_tensor(self,
         return torch.concat(mm_input, dim=dim)

     def _parse_and_validate_audio_input(
-            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
+            self, **kwargs: object) -> Optional[Qwen2AudioFeatureInputs]:
         input_audio_features = kwargs.pop('input_audio_features', None)
         audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
         feature_attention_mask = kwargs.pop('feature_attention_mask', None)
@@ -548,9 +548,10 @@ def _parse_and_validate_audio_input(
         if not isinstance(input_audio_features, (torch.Tensor, list)):
             raise ValueError("Incorrect type of audio input features. "
                              f"Got type: {type(input_audio_features)}")
-        return Qwen2AudioInputs(input_features=input_audio_features,
-                                audio_feature_lengths=audio_feature_lengths,
-                                feature_attention_mask=feature_attention_mask)
+        return Qwen2AudioFeatureInputs(
+            input_features=input_audio_features,
+            audio_feature_lengths=audio_feature_lengths,
+            feature_attention_mask=feature_attention_mask)

     def _parse_and_validate_image_input(
         self,
@@ -630,7 +631,7 @@ def _parse_and_validate_video_input(

     def _process_audio_input(
         self,
-        audio_input: Qwen2AudioInputs,
+        audio_input: Qwen2AudioFeatureInputs,
         audio_hashes: list[str] = None,
         cached_audio_features: torch.Tensor = None,
     ) -> torch.Tensor:
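
Note: the Omni thinker changes are a retyping only. Its audio path still accepts extracted features, now under the renamed Qwen2AudioFeatureInputs; the new embedding variant below lands in the standalone Qwen2-Audio model.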

vllm/model_executor/models/qwen2_audio.py

Lines changed: 86 additions & 23 deletions
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, Literal, Optional, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -36,9 +36,11 @@
 from vllm.config import VllmConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+from vllm.multimodal.inputs import (AudioItem, ModalityData,
+                                    MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
-from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
+from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
+                                   ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
@@ -52,14 +54,25 @@


 # # === Audio Inputs === #
-class Qwen2AudioInputs(TypedDict):
+class Qwen2AudioFeatureInputs(TypedDict):
+    type: Literal["audio_features"]
     input_features: torch.Tensor
     """Shape: `(num_audios, num_mel_bins, 3000)`"""

     feature_attention_mask: torch.Tensor
     """Shape: `(num_audios, 3000)`"""


+class Qwen2AudioEmbeddingInputs(TypedDict):
+    type: Literal["audio_embeds"]
+    audio_embeds: list[torch.Tensor]
+    """Shape: `(num_audio_features, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
+
 # === Audio Encoder === #

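These TypedDicts form a tagged union keyed on the literal `type` field, so callers branch on audio_input["type"] and type checkers can narrow to the right variant. A self-contained sketch of the pattern (toy names and shapes, not the vLLM classes):

    from typing import Literal, TypedDict, Union

    import torch

    class FeatureInputs(TypedDict):
        type: Literal["audio_features"]
        input_features: torch.Tensor

    class EmbeddingInputs(TypedDict):
        type: Literal["audio_embeds"]
        audio_embeds: list[torch.Tensor]

    AudioInputs = Union[FeatureInputs, EmbeddingInputs]

    def process(audio_input: AudioInputs) -> int:
        # Branching on the literal tag narrows the TypedDict variant.
        if audio_input["type"] == "audio_embeds":
            return sum(e.shape[0] for e in audio_input["audio_embeds"])
        return audio_input["input_features"].shape[0]

    hidden_size = 8  # toy value
    print(process({"type": "audio_embeds",
                   "audio_embeds": [torch.zeros(75, hidden_size)]}))  # 75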
@@ -128,12 +141,38 @@ def get_dummy_mm_data(
         }


+def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    return dict(
+        audio_embeds=MultiModalFieldConfig.batched("audio"),
+        input_features=MultiModalFieldConfig.batched("audio"),
+        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+    )
+
+
+class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
+
+    def _parse_audio_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
+    ) -> Optional[ModalityDataItems[Any, Any]]:
+        if isinstance(data, dict):
+            return DictEmbeddingItems(
+                data,
+                modality="audio",
+                required_fields={"audio_embeds"},
+                fields_factory=_qwen2audio_field_config,
+            )
+
+        return super()._parse_audio_data(data)
+
+
 class Qwen2AudioMultiModalProcessor(
         BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):

     def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
-        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
+        return Qwen2AudioMultiModalDataParser(
+            target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
         self,
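
With this parser, the audio entry of the multimodal data may be either a raw waveform (handled by the base parser and resampled to target_sr) or a dict carrying precomputed embeddings (routed to DictEmbeddingItems, bypassing the feature extractor). Illustrative inputs only; the 16 kHz waveform and hidden size are assumptions:

    import numpy as np
    import torch

    hidden_size = 8  # hypothetical; must match the LM backbone

    # Raw waveform: falls through to super()._parse_audio_data(), which
    # resamples to the feature extractor's sampling rate.
    mm_data_waveform = {"audio": np.zeros(16000, dtype=np.float32)}

    # Precomputed embeddings: the dict form is detected by
    # isinstance(data, dict) and wrapped in DictEmbeddingItems.
    mm_data_embeds = {"audio": {"audio_embeds": torch.zeros(75, hidden_size)}}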
@@ -173,17 +212,15 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            input_features=MultiModalFieldConfig.batched("audio"),
-            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
-        )
+        return _qwen2audio_field_config(hf_inputs)

     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
+
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         tokenizer = self.info.get_tokenizer()
         vocab = tokenizer.get_vocab()
@@ -211,7 +248,15 @@ def _get_prompt_updates(
             audio_output_lengths = audio_output_lens.tolist()

         def get_replacement_qwen2_audio(item_idx: int):
-            num_features = audio_output_lengths[item_idx]
+
+            if audio_output_lengths:
+                num_features = audio_output_lengths[item_idx]
+            else:
+                audio_embeds = out_mm_data["audio_embeds"][item_idx]
+                assert len(audio_embeds.shape
+                           ) == 2, "audio_embeds must be a 2D tensor"
+                num_features = audio_embeds.shape[0]
+
             if num_features == 0:
                 audios = mm_items.get_items("audio", AudioProcessorItems)
                 audio_len = audios.get_audio_length(item_idx)
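
With embeddings, the placeholder count for each audio is read straight off the tensor instead of being derived from the feature-attention mask. A toy illustration (hypothetical sizes; the repeated-token string is only illustrative of the prompt expansion):

    import torch

    hidden_size = 8  # toy value
    audio_embeds = torch.zeros(75, hidden_size)  # one audio item's embeddings
    num_features = audio_embeds.shape[0]
    placeholder = "<|AUDIO|>" * num_features    # 75 placeholder positions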
@@ -286,21 +331,39 @@ def _validate_and_reshape_mm_tensor(self, mm_input: object,
     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
         input_features = kwargs.pop('input_features', None)
+        audio_embeds = kwargs.pop('audio_embeds', None)
         feature_attention_mask = kwargs.pop('feature_attention_mask', None)
-        if input_features is None:
+
+        if input_features is None and audio_embeds is None:
             return None
-        input_features = self._validate_and_reshape_mm_tensor(
-            input_features, 'input_features')
-        feature_attention_mask = self._validate_and_reshape_mm_tensor(
-            feature_attention_mask, 'feature_attention_mask')
-        if not isinstance(input_features, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of audio input features. "
-                             f"Got type: {type(input_features)}")
-        return Qwen2AudioInputs(input_features=input_features,
-                                feature_attention_mask=feature_attention_mask)
-
-    def _process_audio_input(self,
-                             audio_input: Qwen2AudioInputs) -> torch.Tensor:
+
+        if audio_embeds is not None:
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio embeds. "
+                                 f"Got type: {type(audio_embeds)}")
+            audio_embeds = self._validate_and_reshape_mm_tensor(
+                audio_embeds, "audio_embeds")
+            return Qwen2AudioEmbeddingInputs(type="audio_embeds",
+                                             audio_embeds=audio_embeds)
+
+        if input_features is not None:
+            input_features = self._validate_and_reshape_mm_tensor(
+                input_features, 'input_features')
+            feature_attention_mask = self._validate_and_reshape_mm_tensor(
+                feature_attention_mask, 'feature_attention_mask')
+            return Qwen2AudioFeatureInputs(
+                type="audio_features",
+                input_features=input_features,
+                feature_attention_mask=feature_attention_mask)
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_audio_input(
+        self, audio_input: Qwen2AudioInputs
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        if audio_input["type"] == "audio_embeds":
+            audio_embeds = audio_input["audio_embeds"]
+            return tuple(audio_embeds)

         input_features = audio_input["input_features"]
         feature_attention_mask = audio_input["feature_attention_mask"]
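
Returning a tuple with one tensor per audio preserves item boundaries, so each audio's rows can be merged into its own placeholder span downstream. A toy check (hypothetical sizes):

    import torch

    hidden_size = 8  # toy value
    embeds = (torch.zeros(75, hidden_size), torch.zeros(120, hidden_size))
    # Total rows must equal the number of audio placeholder tokens in the prompt.
    assert sum(e.shape[0] for e in embeds) == 195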
