# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, Literal, Optional, TypedDict, Union

import torch
import torch.nn as nn
...
from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+from vllm.multimodal.inputs import (AudioItem, ModalityData,
+                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
+from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
+                                   ModalityDataItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
...


# # === Audio Inputs === #
-class Qwen2AudioInputs(TypedDict):
+class Qwen2AudioFeatureInputs(TypedDict):
+    type: Literal["audio_features"]
    input_features: torch.Tensor
    """Shape: `(num_audios, num_mel_bins, 3000)`"""

    feature_attention_mask: torch.Tensor
    """Shape: `(num_audios, 3000)`"""


+class Qwen2AudioEmbeddingInputs(TypedDict):
+    type: Literal["audio_embeds"]
+    audio_embeds: list[torch.Tensor]
+    """Shape: `(num_audio_features, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
+
# === Audio Encoder === #
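Note (illustration, not part of the patch): the two TypedDicts above form a discriminated union keyed on the `type` literal, which the model later uses to choose a processing path. Below is a minimal sketch of what each variant looks like at runtime; the mel-bin count (128) and hidden size (4096) are placeholder values rather than values read from the model config, and `describe` is a hypothetical helper.

```python
import torch

# At runtime a TypedDict is a plain dict; the "type" key discriminates the union.
features_input = {
    "type": "audio_features",
    "input_features": torch.zeros(1, 128, 3000),    # (num_audios, num_mel_bins, 3000)
    "feature_attention_mask": torch.ones(1, 3000),  # (num_audios, 3000)
}
embeds_input = {
    "type": "audio_embeds",
    "audio_embeds": [torch.zeros(75, 4096)],  # (num_audio_features, hidden_size)
}

def describe(audio_input: dict) -> str:
    # Branch on the discriminator, mirroring how the model code dispatches.
    if audio_input["type"] == "audio_features":
        return f"features of shape {tuple(audio_input['input_features'].shape)}"
    return f"{len(audio_input['audio_embeds'])} precomputed embedding tensor(s)"

print(describe(features_input))  # features of shape (1, 128, 3000)
print(describe(embeds_input))    # 1 precomputed embedding tensor(s)
```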
@@ -128,12 +141,38 @@ def get_dummy_mm_data(
        }


+def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    return dict(
+        audio_embeds=MultiModalFieldConfig.batched("audio"),
+        input_features=MultiModalFieldConfig.batched("audio"),
+        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+    )
+
+
+class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
+
+    def _parse_audio_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
+    ) -> Optional[ModalityDataItems[Any, Any]]:
+        if isinstance(data, dict):
+            return DictEmbeddingItems(
+                data,
+                modality="audio",
+                required_fields={"audio_embeds"},
+                fields_factory=_qwen2audio_field_config,
+            )
+
+        return super()._parse_audio_data(data)
+
+
class Qwen2AudioMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_feature_extractor()
-        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
+        return Qwen2AudioMultiModalDataParser(
+            target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
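Note (illustration, not part of the patch): the parser subclass above routes dict-shaped audio data through `DictEmbeddingItems` when the required `audio_embeds` key is present, while raw waveforms still go through the base `MultiModalDataParser`. A sketch of the two per-request data shapes this implies follows; the 16 kHz waveform length and the embedding dimensions are placeholders.

```python
import numpy as np
import torch

# Raw-audio path: a waveform plus its sampling rate; the base parser resamples
# it to the feature extractor's target rate before feature extraction.
raw_audio_item = (np.zeros(16000, dtype=np.float32), 16000)

# Embedding path: a dict carrying the required "audio_embeds" field with
# encoder outputs already projected to the language model's hidden size
# (placeholder shape: one audio item, 75 features, hidden size 4096).
embedding_item = {"audio_embeds": torch.zeros(1, 75, 4096)}

mm_data_with_waveform = {"audio": raw_audio_item}
mm_data_with_embeds = {"audio": embedding_item}
```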
@@ -173,17 +212,15 @@ def _get_mm_fields_config(
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            input_features=MultiModalFieldConfig.batched("audio"),
-            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
-        )
+        return _qwen2audio_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
+
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
@@ -211,7 +248,15 @@ def _get_prompt_updates(
        audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
-            num_features = audio_output_lengths[item_idx]
+
+            if audio_output_lengths:
+                num_features = audio_output_lengths[item_idx]
+            else:
+                audio_embeds = out_mm_data["audio_embeds"][item_idx]
+                assert len(audio_embeds.shape
+                           ) == 2, "audio_embeds must be a 2D tensor"
+                num_features = audio_embeds.shape[0]
+
            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)
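Note (illustration, not part of the patch): `get_replacement_qwen2_audio` now derives the number of audio placeholder positions either from the feature-extractor output lengths or, when only embeddings were supplied, from the embedding row count. A standalone restatement of that fallback with dummy values:

```python
import torch

# Dummy stand-ins: the output lengths list is empty when the request carried
# precomputed embeddings instead of raw audio.
audio_output_lengths: list[int] = []
audio_embeds = [torch.zeros(75, 4096)]  # one item, (num_features, hidden_size)

item_idx = 0
if audio_output_lengths:
    num_features = audio_output_lengths[item_idx]
else:
    assert audio_embeds[item_idx].ndim == 2, "audio_embeds must be a 2D tensor"
    num_features = audio_embeds[item_idx].shape[0]

print(num_features)  # 75 placeholder positions for this audio item
```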
@@ -286,21 +331,39 @@ def _validate_and_reshape_mm_tensor(self, mm_input: object,
    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
        input_features = kwargs.pop('input_features', None)
+        audio_embeds = kwargs.pop('audio_embeds', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
-        if input_features is None:
+
+        if input_features is None and audio_embeds is None:
            return None
-        input_features = self._validate_and_reshape_mm_tensor(
-            input_features, 'input_features')
-        feature_attention_mask = self._validate_and_reshape_mm_tensor(
-            feature_attention_mask, 'feature_attention_mask')
-        if not isinstance(input_features, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of audio input features. "
-                             f"Got type: {type(input_features)}")
-        return Qwen2AudioInputs(input_features=input_features,
-                                feature_attention_mask=feature_attention_mask)
-
-    def _process_audio_input(self,
-                             audio_input: Qwen2AudioInputs) -> torch.Tensor:
+
+        if audio_embeds is not None:
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio embeds. "
+                                 f"Got type: {type(audio_embeds)}")
+            audio_embeds = self._validate_and_reshape_mm_tensor(
+                audio_embeds, "audio_embeds")
+            return Qwen2AudioEmbeddingInputs(type="audio_embeds",
+                                             audio_embeds=audio_embeds)
+
+        if input_features is not None:
+            input_features = self._validate_and_reshape_mm_tensor(
+                input_features, 'input_features')
+            feature_attention_mask = self._validate_and_reshape_mm_tensor(
+                feature_attention_mask, 'feature_attention_mask')
+            return Qwen2AudioFeatureInputs(
+                type="audio_features",
+                input_features=input_features,
+                feature_attention_mask=feature_attention_mask)
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_audio_input(
+        self, audio_input: Qwen2AudioInputs
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
+        if audio_input["type"] == "audio_embeds":
+            audio_embeds = audio_input["audio_embeds"]
+            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]
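Note (illustration, not part of the patch): a rough end-to-end sketch of how the new embedding path might be exercised from the offline API, assuming the request-level layout accepted by the parser above (an "audio" entry holding a dict with an `audio_embeds` tensor). The prompt string is abbreviated rather than taken from the model's chat template, and the embedding shape is a placeholder whose last dimension must match the language backbone's hidden size.

```python
import torch
from vllm import LLM, SamplingParams

# Placeholder embeddings for one audio clip:
# (num_audios, num_features, hidden_size); 4096 is illustrative, not from the config.
audio_embeds = torch.zeros(1, 75, 4096)

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct")
outputs = llm.generate(
    {
        "prompt": "<|audio_bos|><|AUDIO|><|audio_eos|>What is said in the audio?",
        "multi_modal_data": {"audio": {"audio_embeds": audio_embeds}},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```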