
Commit 06da44f

Migrate LlavaImageInputs to TensorSchema (#21770)
Signed-off-by: Benji Beck <[email protected]>
1 parent: a554991 · commit: 06da44f

File tree: 1 file changed (+35, -32 lines)


vllm/model_executor/models/llava.py

Lines changed: 35 additions & 32 deletions
@@ -3,7 +3,7 @@
 
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union, cast)
 
 import torch
@@ -33,6 +33,7 @@
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,35 +45,46 @@
 from .vision import get_vision_encoder_info
 
 
-class LlavaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class LlavaImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class PixtralHFImagePixelInputs(TypedDict):
-    type: Literal["pixel_values_pixtral"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class PixtralHFImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
+
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
+    pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                            TensorShape("bn", "c", "h", "w")]
 
 
-class LlavaImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+class LlavaImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 
 LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
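
Aside (not part of the commit): under TensorSchema, the shape checks live in the Annotated[..., TensorShape(...)] declarations above rather than in hand-written validators. A minimal sketch of the intended behavior, assuming TensorSchema validates each annotated field when the object is constructed and raises ValueError on a mismatch:

    import torch

    # (bn=2, c=3, h=336, w=336) satisfies TensorShape("bn", 3, "h", "w"),
    # so construction should pass validation.
    ok = LlavaImagePixelInputs(pixel_values=torch.zeros(2, 3, 336, 336))

    # Four channels violate the literal 3 in the shape spec, so this
    # construction should be rejected.
    try:
        LlavaImagePixelInputs(pixel_values=torch.zeros(2, 4, 336, 336))
    except ValueError as exc:  # assumed error type for shape mismatches
        print(exc)

The same declarative check covers PixtralHFImagePixelInputs, whose pixel_values may instead be a list of tensors because height and width can differ per image.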
@@ -547,19 +559,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -579,10 +578,14 @@ def _parse_and_validate_image_input(
                 pixel_values=flatten_bn(pixel_values),
             )
 
+            expected_h = expected_w = self.config.vision_config.image_size
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                pixel_values=flatten_bn(pixel_values, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w
+                },
             )
 
         if image_embeds is not None:
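
Aside (not part of the commit): resolve_bindings is what lets the schema absorb the deleted _validate_pixel_values. Pinning the symbolic "h" and "w" dimensions to the configured image size reproduces the old (3, image_size, image_size) check declaratively. A hedged sketch, with image_size standing in for self.config.vision_config.image_size:

    import torch

    image_size = 336  # stand-in for self.config.vision_config.image_size

    inputs = LlavaImagePixelInputs(
        pixel_values=torch.zeros(2, 3, image_size, image_size),
        resolve_bindings={"h": image_size, "w": image_size},
    )

    # With "h" and "w" bound to 336, a 224x224 tensor should now fail
    # schema validation, the same case the removed method rejected by hand.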
