from abc import abstractmethod
from collections.abc import Iterable, Mapping
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                    Union)

import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
-from typing_extensions import NotRequired

from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
    flatten_bn, init_vllm_registered_model, maybe_prefix)

-class LlavaNextImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class LlavaNextImagePixelInputs(TensorSchema):
    """
-    Shape:
-    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
-
+    Dimensions:
+        - bn: Batch size * number of images
+        - np: Number of patches + 1
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]

-    image_sizes: NotRequired[torch.Tensor]
-    """
-    Shape: `(batch_size * num_images, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
+    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
+    # This should be in `(height, width)` format.

-class LlavaNextImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

-    `hidden_size` must match the hidden size of language model backbone.
+class LlavaNextImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match language model backbone)
    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
@@ -269,44 +273,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        expected_dims = (2, )
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    f"The expected shape of image sizes per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
@@ -325,13 +291,15 @@ def _parse_and_validate_image_input(
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

+            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaNextImagePixelInputs(
                type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values)),
-                image_sizes=self._validate_image_sizes(
-                    flatten_bn(image_sizes, concat=True)),
-            )
+                pixel_values=flatten_bn(pixel_values),
+                image_sizes=flatten_bn(image_sizes, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
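Note (illustrative, not part of the diff): with `TensorSchema`, the shape checks that previously lived in `_validate_pixel_values` and `_validate_image_sizes` are declared once on the class via `Annotated[..., TensorShape(...)]`, and symbolic dimensions such as `h`/`w` are pinned at construction through `resolve_bindings`. The sketch below shows how the new schema might be exercised, assuming this diff is against `vllm/model_executor/models/llava_next.py` and that `TensorSchema` validates the annotated shapes when the instance is created; the tensor sizes are made up for illustration.

```python
import torch

from vllm.model_executor.models.llava_next import LlavaNextImagePixelInputs

# Made-up sizes for illustration: 2 images, 5 patches each, 336x336 crops.
bn, num_patches, image_size = 2, 5, 336

inputs = LlavaNextImagePixelInputs(
    type="pixel_values",
    # (bn, np, 3, h, w); "np" is declared as a dynamic dim, so a list of
    # per-image tensors with differing patch counts would also be accepted.
    pixel_values=torch.zeros(bn, num_patches, 3, image_size, image_size),
    # (bn, 2) in (height, width) format.
    image_sizes=torch.tensor([[672, 336], [336, 672]]),
    # Bind the symbolic "h"/"w" dims to the vision encoder's image size,
    # mirroring what _parse_and_validate_image_input does in the diff.
    resolve_bindings={"h": image_size, "w": image_size},
)
```

If the supplied tensors do not match the declared `TensorShape`, construction is expected to fail with a descriptive error, which is what replaces the hand-rolled `ValueError`s removed above.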