@@ -3,7 +3,7 @@

 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union, cast)

 import torch
@@ -33,6 +33,7 @@
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -44,35 +45,46 @@
 from .vision import get_vision_encoder_info


-class LlavaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class LlavaImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


-class PixtralHFImagePixelInputs(TypedDict):
-    type: Literal["pixel_values_pixtral"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class PixtralHFImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
+
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
+    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
+    pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                            TensorShape("bn", "c", "h", "w")]


-class LlavaImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
+class LlavaImageEmbeddingInputs(TensorSchema):
     """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


 LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
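
Note on the pattern introduced above: a TensorSchema subclass declares each field's expected shape through Annotated[..., TensorShape(...)], where integer entries are fixed sizes and string entries are symbolic dimensions, resolved either against the sizes actually seen or against the resolve_bindings mapping used in the last hunk of this commit. The sketch below is a deliberately simplified stand-in for vllm.utils.tensor_schema, written only to illustrate those semantics as they appear in this diff; the names TensorShape and TensorSchema match the import, but every implementation detail here is an assumption, not vLLM's actual code.

# Toy illustration of the TensorSchema/TensorShape semantics used above.
# This is an assumption-laden sketch, NOT vLLM's actual implementation.
from typing import Annotated, get_type_hints

import torch


class TensorShape:
    """Ints are fixed sizes; strings are symbolic dimensions."""

    def __init__(self, *dims):
        self.dims = dims


class TensorSchema:
    """Validates annotated tensor fields at construction time."""

    def __init__(self, resolve_bindings=None, **fields):
        bindings = dict(resolve_bindings or {})
        hints = get_type_hints(type(self), include_extras=True)
        for name, value in fields.items():
            setattr(self, name, value)
            meta = getattr(hints.get(name), "__metadata__", ())
            shape = next((m for m in meta if isinstance(m, TensorShape)),
                         None)
            if shape is None or not isinstance(value, torch.Tensor):
                continue  # e.g. the `type` field, or list-of-tensor inputs
            if value.ndim != len(shape.dims):
                raise ValueError(f"{name}: expected {len(shape.dims)} dims, "
                                 f"got shape {tuple(value.shape)}")
            for dim, actual in zip(shape.dims, value.shape):
                # Symbolic dims bind to the first size seen (or to an entry
                # supplied via resolve_bindings); int dims are fixed.
                expected = (bindings.setdefault(dim, int(actual))
                            if isinstance(dim, str) else dim)
                if int(actual) != expected:
                    raise ValueError(f"{name}: dim {dim!r} should be "
                                     f"{expected}, got {actual}")


# Mirrors LlavaImagePixelInputs from the hunk above:
class ToyPixelInputs(TensorSchema):
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


ToyPixelInputs(pixel_values=torch.zeros(2, 3, 336, 336),
               resolve_bindings={"h": 336, "w": 336})    # passes
# ToyPixelInputs(pixel_values=torch.zeros(2, 3, 224, 336),
#                resolve_bindings={"h": 336, "w": 336})  # ValueError on "h"
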
@@ -547,19 +559,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -579,10 +578,14 @@ def _parse_and_validate_image_input(
                 pixel_values=flatten_bn(pixel_values),
             )

+            expected_h = expected_w = self.config.vision_config.image_size
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                pixel_values=flatten_bn(pixel_values, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w
+                },
             )

         if image_embeds is not None:
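
The resolve_bindings argument is what subsumes the deleted _validate_pixel_values() helper: the expected height and width are only known at runtime, from config.vision_config.image_size, so they are pinned per instance rather than hard-coded into the annotation. A hypothetical call, with 336 standing in for the configured image size:

# Hypothetical usage; 336 stands in for config.vision_config.image_size.
import torch

inputs = LlavaImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.randn(4, 3, 336, 336),  # (bn, 3, h, w)
    resolve_bindings={"h": 336, "w": 336},
)

# A mis-sized tensor, e.g. torch.randn(4, 3, 224, 224), now fails inside
# the schema itself instead of in a hand-written shape check.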