@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union

 import numpy as np
 import torch
@@ -40,6 +40,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
@@ -615,50 +616,90 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class Phi4MMImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
+class Phi4MMImagePixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
+    Dimensions:
+        - bn: Batch size * number of images
+        - p: Number of patches (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height of each image patch
+        - w: Width of each image patch
+        - nc: Number of crops
+        - H_mask: Height of attention mask
+        - W_mask: Width of attention mask
     """

-    image_sizes: torch.Tensor
-    """
-    Shape: `(batch_size * num_images, 2)`
+    type: Literal["pixel_values"]

-    This should be in `(height, width)` format.
-    """
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
+                    ),  # may be different per batch and image
+    ]

-    num_img_tokens: list[int]
-    """Shape: `(batch_size * num_images)`"""
+    image_sizes: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 2),  # (height, width)
+    ]

-    image_attention_mask: torch.Tensor
-    """Shape: `(batch_size * num_images, H_mask, W_mask)`"""
+    num_img_tokens: Annotated[
+        list[int],
+        TensorShape("bn"),
+    ]

+    image_attention_mask: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "nc", 32, 32),  # H_mask, W_mask
+    ]

-class Phi4MMImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

-    `hidden_size` must match the hidden size of language model backbone.
+class Phi4MMImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match language model backbone)
+    """
+
+    type: Literal["image_embeds"]
+
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "f", "h"),
+    ]
+

+class Phi4MMAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - f: Number of Mel filterbank bins (80)
+        - t: Time frames (M)
+    """

-class Phi4MMAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_audios, 80, M)"""

+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
+    ]
+
+
+class Phi4MMAudioEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of audios
+        - f: Audio feature size
+        - h: Hidden size (must match language model backbone)
+    """

-class Phi4MMAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    data: NestedTensors
-    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
+
+    data: Annotated[
+        NestedTensors,
+        TensorShape("b", "n", "f", "h"),
+    ]


 Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
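
The hunk above replaces hand-written `TypedDict`s with `TensorSchema` subclasses, turning the old shape docstrings into machine-checkable `Annotated[..., TensorShape(...)]` metadata: string dims such as `"bn"` are symbolic and must agree wherever they reappear, integer dims such as `3` are fixed, and dims listed in `dynamic_dims` are exempt from the consistency check. A minimal, self-contained sketch of that idea follows; it borrows the `TensorShape`/`TensorSchema` names but is only an illustration, not vLLM's actual implementation in `vllm.utils.tensor_schema`:

```python
from typing import get_type_hints

import torch


class TensorShape:
    """Shape spec: str dims bind by name, int dims are fixed sizes."""

    def __init__(self, *dims, dynamic_dims=frozenset()):
        self.dims = dims
        self.dynamic_dims = set(dynamic_dims)

    def check(self, tensor: torch.Tensor, bindings: dict) -> None:
        if tensor.ndim != len(self.dims):
            raise ValueError(
                f"expected {len(self.dims)} dims, got {tensor.ndim}")
        for dim, size in zip(self.dims, tensor.shape):
            if isinstance(dim, int):
                # Fixed dim, e.g. the 3 color channels or the 80 Mel bins.
                if size != dim:
                    raise ValueError(f"expected size {dim}, got {size}")
            elif dim not in self.dynamic_dims:
                # Symbolic dim: the first occurrence binds it, later ones
                # must match; dynamic dims skip this check entirely.
                if bindings.setdefault(dim, size) != size:
                    raise ValueError(
                        f"dim {dim!r}: bound to {bindings[dim]}, got {size}")


class TensorSchema:
    """Checks every Annotated field against its TensorShape on init."""

    def __init__(self, **kwargs):
        bindings: dict = {}  # symbolic dims are shared across all fields
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            value = kwargs.get(name)
            for meta in getattr(hint, "__metadata__", ()):
                if isinstance(meta, TensorShape) and isinstance(
                        value, torch.Tensor):
                    meta.check(value, bindings)
            setattr(self, name, value)
```

Because `bindings` is shared across fields, a mismatch such as `data` implying `bn == 4` while `image_sizes` has 5 rows is caught at construction time, which is exactly the contract the old docstrings could only describe.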
@@ -1170,18 +1211,10 @@ def _parse_and_validate_audio_input(
             return None

         if audio_features is not None:
-            if not isinstance(audio_features, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio features. "
-                                 f"Got type: {type(audio_features)}")
-
             return Phi4MMAudioFeatureInputs(type="audio_features",
                                             data=flatten_bn(audio_features))

         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio embeds. "
-                                 f"Got type: {type(audio_embeds)}")
-
             return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
                                               data=audio_embeds)

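
This hunk is the payoff of the schema migration: the manual `isinstance` guards become redundant because constructing the inputs object is itself the validation step, assuming `TensorSchema` rejects mismatched types and shapes in `__init__` (the premise that lets this diff delete the checks). A usage sketch with made-up shapes, reusing `Phi4MMAudioFeatureInputs` from the first hunk:

```python
import torch

# (bn, t, 80): "t" is in dynamic_dims, so any number of time frames passes.
ok = Phi4MMAudioFeatureInputs(type="audio_features",
                              data=torch.rand(4, 17, 80))

# Wrong trailing dim (81 != 80 Mel bins): rejected up front at construction
# rather than surfacing later as a cryptic shape error inside the encoder.
try:
    Phi4MMAudioFeatureInputs(type="audio_features",
                             data=torch.rand(4, 17, 81))
except ValueError as exc:
    print(f"rejected: {exc}")
```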
@@ -1259,7 +1292,7 @@ def _parse_and_validate_image_input(
         elif isinstance(image_sizes, torch.Tensor):
             image_sizes = image_sizes.flatten(0, 1)
         else:
-            raise ValueError("Incorrect image_attention_mask inputs")
+            raise ValueError("Incorrect image_sizes inputs")

         if isinstance(num_img_tokens, list):
             num_img_tokens = [
@@ -1269,7 +1302,7 @@ def _parse_and_validate_image_input(
         elif isinstance(num_img_tokens, torch.Tensor):
             num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
         else:
-            raise ValueError("Incorrect image_attention_mask inputs")
+            raise ValueError("Incorrect num_img_tokens inputs")

         return Phi4MMImagePixelInputs(
             type="pixel_values",
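
One schema detail worth calling out alongside these fixes: `dynamic_dims={"p"}` in `Phi4MMImagePixelInputs` is what encodes the old docstring's caveat that the patch count may differ per image, and it is also why `data` is typed `Union[torch.Tensor, list[torch.Tensor]]`. A hypothetical illustration (the 448x448 patch size is invented for the example):

```python
import torch

# Per-image patch counts differ (6 vs 10, i.e. 1 global view + N crops), so
# the tensors cannot be stacked into one (bn, p, 3, h, w) batch; they travel
# as a list instead, and "p" being dynamic is what makes that legal. The
# fixed channel dim (3) and the shared "h"/"w" sizes must still agree.
pixels = [
    torch.rand(6, 3, 448, 448),   # image 1: 1 + 5 patches
    torch.rand(10, 3, 448, 448),  # image 2: 1 + 9 patches
]
```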