import math
from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union

import torch
import torch.nn as nn
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .vision import get_vision_encoder_info


-class LlavaNextVideoPixelInputs(TypedDict):
-    type: Literal["pixel_values_videos"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size, num_frames, num_channels, height, width)`
+class LlavaNextVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bs: Batch size
+        - nv: Number of videos
+        - nf: Number of frames
+        - nc: Number of channels (3)
+        - h: Height of each frame
+        - w: Width of each frame

    Note that `num_frames` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.

    Note that it only supports one video input for one batch.
    """
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+
+    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+                    TensorShape("bs", "nv", "nf", 3, "h", "w")]
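For reference, a minimal sketch of how the new schema is expected to behave, assuming TensorSchema (from vllm.utils.tensor_schema) validates the TensorShape annotation eagerly at construction and raises ValueError on a mismatch, as the deleted hand-written validator did. The module path and the 336x336 resolution below are illustrative assumptions, not part of this diff:

import torch

from vllm.model_executor.models.llava_next_video import (
    LlavaNextVideoPixelInputs)

# Batch of 2, one video each, 16 frames at 336x336: matches
# TensorShape("bs", "nv", "nf", 3, "h", "w") once h/w are bound.
ok = LlavaNextVideoPixelInputs(
    type="pixel_values_videos",
    data=torch.zeros(2, 1, 16, 3, 336, 336),
    resolve_bindings={"h": 336, "w": 336},
)

# A wrong channel count (the literal 3 in the annotation) should now be
# caught by the schema rather than by ad-hoc model code.
try:
    LlavaNextVideoPixelInputs(
        type="pixel_values_videos",
        data=torch.zeros(2, 1, 16, 4, 336, 336),  # 4 channels: invalid
        resolve_bindings={"h": 336, "w": 336},
    )
except ValueError as e:
    print(e)  # shape mismatch reported against the annotation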


class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
@@ -320,27 +329,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        self.make_empty_intermediate_tensors = (
            self.language_model.model.make_empty_intermediate_tensors)

-    def _validate_video_pixel_values(
-        self, data: Union[torch.Tensor, list[torch.Tensor]]
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[2:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_frames", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values in each video frame "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
        """
@@ -355,14 +343,13 @@ def _parse_and_validate_video_input(
        if pixel_values_videos is None:
            return None

-        if not isinstance(pixel_values_videos, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel_values_videos. "
-                             f"Got type: {type(pixel_values_videos)}")
-
-        return LlavaNextVideoPixelInputs(
-            type="pixel_values_videos",
-            data=pixel_values_videos,
-        )
+        expected_h = expected_w = self.config.vision_config.image_size
+        return LlavaNextVideoPixelInputs(type="pixel_values_videos",
+                                         data=pixel_values_videos,
+                                         resolve_bindings={
+                                             "h": expected_h,
+                                             "w": expected_w,
+                                         })
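Design note: only "h" and "w" are pinned via resolve_bindings because they are the only dimensions known from the vision config; "bs", "nv" and "nf" stay symbolic. A hedged sketch of the list-input case the docstring allows, assuming (per the old validator's d.shape[2:] check) that each list element carries the per-item (nv, nf, 3, h, w) dims; the shapes below are illustrative:

data = [
    torch.zeros(1, 8, 3, 336, 336),   # video 1: 8 frames
    torch.zeros(1, 24, 3, 336, 336),  # video 2: 24 frames
]
# With "nf" left unbound in TensorShape("bs", "nv", "nf", 3, "h", "w"),
# both entries should validate even though their frame counts differ.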

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor: