Commit b4e2916

Migrate LlavaNextImageInputs to TensorSchema (#21774)
Signed-off-by: Benji Beck <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 65a7917 commit b4e2916

File tree: 2 files changed (+34 −63 lines)

  vllm/model_executor/models/llava_next.py
  vllm/utils/tensor_schema.py

vllm/model_executor/models/llava_next.py

Lines changed: 31 additions & 63 deletions
@@ -3,22 +3,22 @@
 
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union)
 
 import torch
 import torch.nn as nn
 from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
-from typing_extensions import NotRequired
 
 from vllm.config import VllmConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.multimodal.parse import ImageSize
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -30,32 +30,36 @@
                     flatten_bn, init_vllm_registered_model, maybe_prefix)
 
 
-class LlavaNextImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class LlavaNextImagePixelInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - np: Number of patches + 1
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
     Note that `num_patches` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
 
-    image_sizes: NotRequired[torch.Tensor]
-    """
-    Shape: `(batch_size * num_images, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
+    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
+    # This should be in `(height, width)` format.
 
-class LlavaNextImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
 
-    `hidden_size` must match the hidden size of language model backbone.
+class LlavaNextImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
     """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 
 LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
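
The schema classes above replace hand-written runtime checks with declarative shape annotations. Below is a minimal sketch, not part of the commit, of how the pixel-input schema might be instantiated: the construction pattern mirrors the updated call in `_parse_and_validate_image_input` further down, but the 336-pixel resolution, patch counts, and image sizes are illustrative assumptions.

import torch

from vllm.model_executor.models.llava_next import LlavaNextImagePixelInputs

image_size = 336  # assumed value of vision_config.image_size

# Two images with different patch counts, so pixel_values is a list
# (the "np" dimension is declared dynamic in the schema above).
pixel_values = [
    torch.zeros(5, 3, image_size, image_size),  # base tile + 4 patches
    torch.zeros(3, 3, image_size, image_size),  # base tile + 2 patches
]
image_sizes = torch.tensor([[672, 672], [336, 672]])  # (height, width) per image

inputs = LlavaNextImagePixelInputs(
    type="pixel_values",
    pixel_values=pixel_values,
    image_sizes=image_sizes,
    # Pins the symbolic "h"/"w" dims so the schema can check the spatial sizes.
    resolve_bindings={"h": image_size, "w": image_size},
)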
@@ -269,44 +273,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        expected_dims = (2, )
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    f"The expected shape of image sizes per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -325,13 +291,15 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            expected_h = expected_w = self.config.vision_config.image_size
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values)),
-                image_sizes=self._validate_image_sizes(
-                    flatten_bn(image_sizes, concat=True)),
-            )
+                pixel_values=flatten_bn(pixel_values),
+                image_sizes=flatten_bn(image_sizes, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
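
With the manual `_validate_pixel_values` and `_validate_image_sizes` helpers removed, shape errors are expected to surface when the schema object itself is constructed. A hedged sketch of what a caller might observe, assuming (as with the removed helpers) that a mismatch raises ValueError:

import torch

from vllm.model_executor.models.llava_next import LlavaNextImagePixelInputs

bad_pixel_values = [torch.zeros(5, 4, 336, 336)]  # 4 channels where the schema expects 3
try:
    LlavaNextImagePixelInputs(
        type="pixel_values",
        pixel_values=bad_pixel_values,
        image_sizes=torch.tensor([[672, 672]]),
        resolve_bindings={"h": 336, "w": 336},
    )
except ValueError as e:
    # The schema declaration, not hand-written checks, reports the mismatch.
    print(f"rejected: {e}")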

vllm/utils/tensor_schema.py

Lines changed: 3 additions & 0 deletions
@@ -60,6 +60,9 @@ def __init__(self,
     def __getitem__(self, item) -> Any:
         return getattr(self, item)
 
+    def get(self, item, default=None) -> Any:
+        return getattr(self, item, default)
+
     def _match_shape_with_dynamic(self, actual: tuple[int, ...],
                                   reference: tuple[int, ...],
                                   expected_shape: tuple[Union[int, str], ...],
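
The new `get` helper mirrors `__getitem__` and keeps dict-style access working for code that previously consumed these inputs as TypedDicts. A small usage sketch under the same assumptions as the earlier examples; `missing_field` is a hypothetical name used only to show the default path:

import torch

from vllm.model_executor.models.llava_next import LlavaNextImagePixelInputs

image_input = LlavaNextImagePixelInputs(
    type="pixel_values",
    pixel_values=[torch.zeros(5, 3, 336, 336)],
    image_sizes=torch.tensor([[672, 672]]),
    resolve_bindings={"h": 336, "w": 336},
)

sizes = image_input["image_sizes"]              # __getitem__ delegates to getattr
sizes_again = image_input.get("image_sizes")    # same value via the new helper
missing = image_input.get("missing_field", None)  # default instead of AttributeError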
