Commit 4678503

Migrate MiniCPMVImageInputs to TensorSchema (#21939)
Signed-off-by: Benji Beck <[email protected]>
1 parent 93d0652 commit 4678503

File tree

1 file changed: +36 -29 lines

vllm/model_executor/models/minicpmv.py

Lines changed: 36 additions & 29 deletions
@@ -27,7 +27,7 @@
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -63,6 +63,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import flatten_2d_lists
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -74,36 +75,47 @@
 _MAX_FRAMES_PER_VIDEO = 16
 
 
-class MiniCPMVImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: list[torch.Tensor]
+class MiniCPMVImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
-
-    Note that the image size may vary, so we pass it as a list
-    instead of a batched tensor.
+    Dimensions:
+        - bns: Batch size * number of images * number of slices
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
     """
 
-    tgt_sizes: torch.Tensor
+    type: Literal["pixel_values"] = "pixel_values"
+
+    # Note that the image size may vary, so we pass it as a list instead of a
+    # batched tensor.
+    pixel_values: Annotated[
+        list[torch.Tensor],
+        TensorShape("bns", "c", "h", "w"),
+    ]
+    tgt_sizes: Annotated[
+        torch.Tensor,
+        TensorShape("bns", 2),  # This should be in `(height, width)` format.
+    ]
+    num_slices: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
+
+
+class MiniCPMVImageEmbeddingInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images * num_slices, 2)`
-
-    This should be in `(height, width)` format.
+    Dimensions:
+        - bn: Batch size * number of images
+        - ns: Number of slices
+        - hs: Hidden size (must match language model backbone)
     """
 
-    num_slices: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class MiniCPMVImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    image_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size * num_images, num_slices, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
-    """
+    image_embeds: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "ns", "hs"),
+    ]
 
 
 MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
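
Usage note (not part of the diff): the migrated class keeps keyword construction, as the unchanged call site in the next hunk shows. Below is a minimal sketch of building the new pixel-value inputs, assuming vLLM is installed and that TensorSchema checks each field against its TensorShape annotation when the instance is created. All sizes are made up for illustration, and each pixel_values entry is assumed to be one slice in (c, h, w) layout with the list length supplying the "bns" dimension.

import torch

from vllm.model_executor.models.minicpmv import MiniCPMVImagePixelInputs

# Hypothetical sizes: 2 images contributing 3 slices in total (bn = 2, bns = 3).
# Slice resolutions may differ, which is why pixel_values is a list.
pixel_values = [
    torch.rand(3, 448, 448),
    torch.rand(3, 448, 224),
    torch.rand(3, 448, 448),
]
tgt_sizes = torch.tensor([[32, 32], [32, 16], [32, 32]])  # (bns, 2) as (h, w)
num_slices = torch.tensor([2, 1])                          # (bn,)

inputs = MiniCPMVImagePixelInputs(
    type="pixel_values",
    pixel_values=pixel_values,
    tgt_sizes=tgt_sizes,
    num_slices=num_slices,
)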
@@ -832,11 +844,6 @@ def _parse_and_validate_vision_input(
         pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
         tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
 
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError("Inconsistent flattened lengths, found: "
-                             f"{len(pixel_values_flat)} vs. "
-                             f"{len(tgt_sizes_flat)}")
-
         return MiniCPMVImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values_flat,
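
Note (not part of the diff): the removed length check becomes redundant because pixel_values and tgt_sizes now both declare the shared "bns" dimension, so the schema can enforce that the flattened lengths agree. The following standalone toy sketch illustrates that idea of shared symbolic dimensions; it is not vLLM's actual TensorSchema implementation.

import torch


def check_shared_dims(fields):
    """Toy cross-field check: dimensions bound to the same symbol (e.g. "bns")
    must resolve to the same size in every field that declares them."""
    resolved = {}
    for name, (value, shape) in fields.items():
        # A list-of-tensors field contributes its length as the leading dim.
        sizes = (len(value),) if isinstance(value, list) else tuple(value.shape)
        for sym, size in zip(shape, sizes):
            expected = resolved.setdefault(sym, size) if isinstance(sym, str) else sym
            if size != expected:
                raise ValueError(
                    f"{name}: dimension {sym!r} is {size}, expected {expected}")


# Mirrors the removed check: both fields must agree on the "bns" size.
pixel_values_flat = [torch.rand(3, 448, 448) for _ in range(5)]
tgt_sizes_flat = torch.randint(1, 64, (5, 2))
check_shared_dims({
    "pixel_values": (pixel_values_flat, ("bns",)),
    "tgt_sizes": (tgt_sizes_flat, ("bns", 2)),
})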
