
Commit 437c3ce

Migrate Phi4 inputs to TensorSchema (#23471)
Signed-off-by: Benji Beck <[email protected]>
1 parent 499b074 commit 437c3ce
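This commit replaces the TypedDict definitions of the Phi-4 multimodal image and audio input types with TensorSchema subclasses whose fields carry Annotated[...] / TensorShape(...) metadata, and drops the manual isinstance checks in _parse_and_validate_audio_input that the schema now makes redundant. A minimal sketch of the resulting pattern, assuming TensorSchema validates the annotated shapes when an instance is constructed (the ToyAudioInputs class and the tensor values below are illustrative only, not part of this diff):

# Illustrative sketch; assumes vllm.utils.tensor_schema checks the Annotated
# TensorShape metadata at construction time, which is what makes the removed
# isinstance checks redundant.
from typing import Annotated, Literal, Union

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ToyAudioInputs(TensorSchema):  # hypothetical class, not from this commit
    """
    Dimensions:
        - bn: Batch size * number of audios
        - t: Time frames (dynamic per item)
    """

    type: Literal["audio_features"]

    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
    ]


# A (4, 37, 80) batch of Mel features matches ("bn", "t", 80) ...
ok = ToyAudioInputs(type="audio_features", data=torch.zeros(4, 37, 80))

# ... while a wrong trailing dimension should be rejected by the schema.
try:
    ToyAudioInputs(type="audio_features", data=torch.zeros(4, 37, 81))
except Exception as exc:  # exact exception type depends on TensorSchema
    print(f"rejected: {exc}")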

File tree: 2 files changed (+129 −73 lines)

vllm/model_executor/models/phi4_multimodal.py
Lines changed: 72 additions & 39 deletions

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -40,6 +40,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
@@ -615,50 +616,90 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Phi4MMImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
+class Phi4MMImagePixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
+    Dimensions:
+        - bn: Batch size * number of images
+        - p: Number of patches (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height of each image patch
+        - w: Width of each image patch
+        - nc: Number of crops
+        - H_mask: Height of attention mask
+        - W_mask: Width of attention mask
     """
 
-    image_sizes: torch.Tensor
-    """
-    Shape: `(batch_size * num_images, 2)`
+    type: Literal["pixel_values"]
 
-    This should be in `(height, width)` format.
-    """
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
+                    ),  # may be different per batch and image
+    ]
 
-    num_img_tokens: list[int]
-    """Shape: `(batch_size * num_images)`"""
+    image_sizes: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 2),  # (height, width)
+    ]
 
-    image_attention_mask: torch.Tensor
-    """Shape: `(batch_size * num_images, H_mask, W_mask)`"""
+    num_img_tokens: Annotated[
+        list[int],
+        TensorShape("bn"),
+    ]
 
+    image_attention_mask: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "nc", 32, 32),  # H_mask, W_mask
+    ]
 
-class Phi4MMImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
-    `hidden_size` must match the hidden size of language model backbone.
+class Phi4MMImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match language model backbone)
+    """
+
+    type: Literal["image_embeds"]
+
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "f", "h"),
+    ]
+
 
+class Phi4MMAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - f: Number of Mel filterbank bins (80)
+        - t: Time frames (M)
+    """
 
-class Phi4MMAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_audios, 80, M)"""
 
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
+    ]
+
+
+class Phi4MMAudioEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of audios
+        - f: Audio feature size
+        - h: Hidden size (must match language model backbone)
+    """
 
-class Phi4MMAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    data: NestedTensors
-    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
+
+    data: Annotated[
+        NestedTensors,
+        TensorShape("b", "n", "f", "h"),
+    ]
 
 
 Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
@@ -1170,18 +1211,10 @@ def _parse_and_validate_audio_input(
             return None
 
         if audio_features is not None:
-            if not isinstance(audio_features, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio features. "
-                                 f"Got type: {type(audio_features)}")
-
            return Phi4MMAudioFeatureInputs(type="audio_features",
                                            data=flatten_bn(audio_features))
 
        if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio embeds. "
-                                 f"Got type: {type(audio_embeds)}")
-
            return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
                                              data=audio_embeds)
 
@@ -1259,7 +1292,7 @@ def _parse_and_validate_image_input(
        elif isinstance(image_sizes, torch.Tensor):
            image_sizes = image_sizes.flatten(0, 1)
        else:
-            raise ValueError("Incorrect image_attention_mask inputs")
+            raise ValueError("Incorrect image_sizes inputs")
 
        if isinstance(num_img_tokens, list):
            num_img_tokens = [
@@ -1269,7 +1302,7 @@ def _parse_and_validate_image_input(
        elif isinstance(num_img_tokens, torch.Tensor):
            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
        else:
-            raise ValueError("Incorrect image_attention_mask inputs")
+            raise ValueError("Incorrect num_img_tokens inputs")
 
        return Phi4MMImagePixelInputs(
            type="pixel_values",

vllm/model_executor/models/phi4mm.py
Lines changed: 57 additions & 34 deletions

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -31,6 +31,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
@@ -391,41 +392,71 @@ def forward(self, pixel_values: torch.FloatTensor,
         return img_set_tensor
 
 
-class Phi4MMImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
+class Phi4MMImagePixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different per batch and image,
-    in which case the data is passed as a list instead of a batched tensor.
+    Dimensions:
+        - bn: Batch size * number of images
+        - p: Number of patches (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height of each image patch
+        - w: Width of each image patch
+        - nc: Number of crops
+        - H_mask: Height of attention mask
+        - W_mask: Width of attention mask
     """
 
-    image_sizes: torch.Tensor
-    """
-    Shape: `(batch_size * num_images, 2)`
+    type: Literal["pixel_values"]
 
-    This should be in `(height, width)` format.
-    """
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
+                    ),  # may be different per batch and image
+    ]
+
+    image_sizes: Annotated[
+        torch.Tensor,
+        TensorShape("bn", 2),  # (height, width)
+    ]
 
-    num_img_tokens: list[int]
-    """Shape: `(batch_size * num_images)`"""
+    num_img_tokens: Annotated[
+        list[int],
+        TensorShape("bn"),
+    ]
 
-    image_attention_mask: torch.Tensor
-    """Shape: `(batch_size * num_images, H_mask, W_mask)`"""
+    image_attention_mask: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "nc", 32, 32),  # H_mask, W_mask
+    ]
 
 
-class Phi4MMAudioFeatureInputs(TypedDict):
+class Phi4MMAudioFeatureInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - t: Time frames (M)
+    """
+
     type: Literal["audio_features"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size * num_audios, 80, M)"""
 
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "t", 80, dynamic_dims={"t"}),
+    ]
 
 
-class Phi4MMAudioEmbeddingInputs(TypedDict):
+class Phi4MMAudioEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - n: Number of audios
+        - f: Audio feature size
+        - h: Hidden size (must match language model backbone)
+    """
     type: Literal["audio_embeds"]
-    data: NestedTensors
-    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
+
+    data: Annotated[
+        NestedTensors,
+        TensorShape("b", "n", "f", "h"),
+    ]
 
 
 Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
@@ -985,18 +1016,10 @@ def _parse_and_validate_audio_input(
            return None
 
        if audio_features is not None:
-            if not isinstance(audio_features, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio features. "
-                                 f"Got type: {type(audio_features)}")
-
            return Phi4MMAudioFeatureInputs(type="audio_features",
                                            data=flatten_bn(audio_features))
 
        if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio embeds. "
-                                 f"Got type: {type(audio_embeds)}")
-
            return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
                                              data=audio_embeds)
 
@@ -1074,7 +1097,7 @@ def _parse_and_validate_image_input(
        elif isinstance(image_sizes, torch.Tensor):
            image_sizes = image_sizes.flatten(0, 1)
        else:
-            raise ValueError("Incorrect image_attention_mask inputs")
+            raise ValueError("Incorrect image_sizes inputs")
 
        if isinstance(num_img_tokens, list):
            num_img_tokens = [
@@ -1084,7 +1107,7 @@ def _parse_and_validate_image_input(
        elif isinstance(num_img_tokens, torch.Tensor):
            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
        else:
-            raise ValueError("Incorrect image_attention_mask inputs")
+            raise ValueError("Incorrect num_img_tokens inputs")
 
        return Phi4MMImagePixelInputs(
            type="pixel_values",
