@@ -4,7 +4,7 @@

from collections.abc import Iterable, Mapping
from functools import partial
- from typing import Literal, Optional, TypedDict, Union
+ from typing import Annotated, Literal, Optional, Union

import torch
import torch.nn as nn
@@ -14,7 +14,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
- from vllm.model_executor.models.ovis import OvisImagePatchInputs, VisualEmbedding
+ from vllm.model_executor.models.ovis import VisualEmbedding
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
@@ -37,6 +37,7 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
+ from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

@@ -58,36 +59,38 @@
}


- class OvisVideoPatchInputs(TypedDict):
-     type: Literal["video_patches"]
-     flat_data: torch.Tensor
+ class Ovis2_5ImagePatchInputs(TensorSchema):
    """
-     Shape:
-     `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
+     Dimensions:
+         - bnp: Batch size * number of images * number of patches
+         - patch_size: patch_size_x * patch_size_y * num_channels
+         - patch_indicators: Batch size * (number of patches + 1)
+         - bn: Batch size * number of images
    """

-     indicator_tokens: torch.Tensor
-     """
-     Shape:
-     `(batch_size * (num_patches + 1))`
-     """
+     type: Literal["image_patches"]
+     flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
+     indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+     patches_per_item: Annotated[list[int], TensorShape("bn")]
+     grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
+     # patches_per_item is used to restore the first two dimensions of `flat_data`.
+

-     patches_per_image: list[int]
+ class Ovis2_5VideoPatchInputs(TensorSchema):
    """
-     List of number of total patches for each frame in the video.
-     This is used to restore the first two dimensions of `flat_data`.
+     Dimensions:
+         - bnp: Batch size * number of videos * number of patches
+         - patch_size: patch_size_x * patch_size_y * num_channels
+         - patch_indicators: Batch size * (number of patches + 1)
+         - bn: Batch size * number of videos
    """

-
- def _ovis2_5_field_config():
-     return dict(
-         pixel_values=MultiModalFieldConfig.batched("image"),
-         grids=MultiModalFieldConfig.batched("image"),
-         indicator_tokens=MultiModalFieldConfig.batched("image"),
-         video_pixel_values=MultiModalFieldConfig.batched("video"),
-         video_indicator_tokens=MultiModalFieldConfig.batched("video"),
-         video_grids=MultiModalFieldConfig.batched("video"),
-     )
+     type: Literal["video_patches"]
+     flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
+     indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+     patches_per_item: Annotated[list[int], TensorShape("bn")]
+     grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
+     # patches_per_item is used to restore the first two dimensions of `flat_data`.


class VisualTokenizer(torch.nn.Module):
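The two `TensorSchema` classes above replace the prose shape docstrings of the old `TypedDict`s with machine-checkable `Annotated[..., TensorShape(...)]` metadata: symbolic names such as `bnp` are bound to runtime sizes, and fixed sizes such as the `3` in `grids` must match exactly. As a rough standalone sketch of the idea (illustrative only; `check_shape` is a hypothetical helper, not vLLM's actual `TensorSchema` machinery):

```python
import torch

def check_shape(t: torch.Tensor, dims: tuple, bindings: dict) -> None:
    """Bind symbolic dimension names to sizes; int dims must match exactly."""
    assert t.ndim == len(dims), f"expected {len(dims)} dims, got {t.ndim}"
    for name, size in zip(dims, t.shape):
        if isinstance(name, int):      # fixed dimension, e.g. 3 for (t, h, w) grids
            assert size == name
        elif name in bindings:         # symbolic dim seen before: sizes must agree
            assert bindings[name] == size
        else:                          # first occurrence: bind the name
            bindings[name] = size

bindings: dict = {}
flat_data = torch.zeros(1024, 588)     # TensorShape("bnp", "patch_size")
grids = torch.zeros(2, 3)              # TensorShape("bn", 3)
check_shape(flat_data, ("bnp", "patch_size"), bindings)
check_shape(grids, ("bn", 3), bindings)
print(bindings)  # {'bnp': 1024, 'patch_size': 588, 'bn': 2}
```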
@@ -380,7 +383,7 @@ def _call_hf_processor(
                self.visual_indicators_to_visual_tokens(indicator)
                for indicator in visual_indicators
            ]
-             processed_outputs["video_indicator_tokens"] = indicator_tokens
+             processed_outputs["video_indicator_tokens"] = torch.tensor(indicator_tokens)
        if "images" in mm_data:
            visual_indicators = [
                hf_processor.construct_visual_indicators((1, 1, 1), False)
@@ -391,7 +394,7 @@ def _call_hf_processor(
                for indicator in visual_indicators
            ]

-             processed_outputs["indicator_tokens"] = indicator_tokens
+             processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
        return processed_outputs

    def _apply_hf_processor_tokens_only(
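Note that both `_call_hf_processor` hunks now wrap the indicator token lists in `torch.tensor(...)` before storing them in `processed_outputs`. This is presumably required by the field-merging path enabled via `merge_by_field_config = True` below, which operates on tensors rather than Python lists; a tiny illustration with hypothetical token ids:

```python
import torch

# Hypothetical indicator ids, one row per visual item.
indicator_tokens = [[151665, 151666, 151667], [151665, 151666, 151667]]
as_tensor = torch.tensor(indicator_tokens)  # int64 tensor of shape (2, 3)
```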
@@ -405,7 +408,14 @@ def _get_mm_fields_config(
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
-         return _ovis2_5_field_config()
+         return dict(
+             pixel_values=MultiModalFieldConfig.batched("image"),
+             grids=MultiModalFieldConfig.batched("image"),
+             indicator_tokens=MultiModalFieldConfig.batched("image"),
+             video_pixel_values=MultiModalFieldConfig.batched("video"),
+             video_indicator_tokens=MultiModalFieldConfig.batched("video"),
+             video_grids=MultiModalFieldConfig.batched("video"),
+         )

    def _get_prompt_updates(
        self,
@@ -441,6 +451,8 @@ def get_replacement_ovis(item_idx, modality: str):
    dummy_inputs=Ovis2_5DummyInputsBuilder,
)
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
+     merge_by_field_config = True
+
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
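With `merge_by_field_config = True`, vLLM merges the multimodal kwargs according to `_get_mm_fields_config` before the model sees them, so one level of batch nesting disappears; this is why the `flatten_bn(flatten_bn(...))` double calls in the two parsing hunks below collapse to a single `flatten_bn(..., concat=True)`. A minimal sketch of the concatenating behavior (an illustrative stand-in, not the actual `vllm.model_executor.models.utils.flatten_bn`):

```python
import torch

def flatten_bn_sketch(x: list[torch.Tensor], concat: bool = False):
    """Flatten a per-item list of tensors; with concat=True, join along dim 0."""
    return torch.cat(x, dim=0) if concat else x

# Two images contributing 12 and 4 ViT patches, 588 values per patch:
pixel_values = [torch.zeros(12, 588), torch.zeros(4, 588)]
flat = flatten_bn_sketch(pixel_values, concat=True)  # shape (16, 588), the "bnp" axis
```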
@@ -470,7 +482,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

    def _parse_and_validate_image_input(
        self, **kwargs: object
-     ) -> Optional[OvisImagePatchInputs]:
+     ) -> Optional[Ovis2_5ImagePatchInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        indicator_tokens = kwargs.pop("indicator_tokens", None)
        grids = kwargs.pop("grids", None)
@@ -489,22 +501,22 @@ def _parse_and_validate_image_input(
                f"Got type: {type(indicator_tokens)}"
            )

-             return OvisImagePatchInputs(
+             return Ovis2_5ImagePatchInputs(
                type="image_patches",
-                 flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
-                 patches_per_image=[
+                 flat_data=flatten_bn(pixel_values, concat=True),
+                 patches_per_item=[
                    x.shape[0] // (self.config.vit_config.hidden_stride**2)
-                     for x in flatten_bn(pixel_values)
+                     for x in pixel_values
                ],
-                 indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
-                 grids=flatten_bn(flatten_bn(grids), concat=True),
+                 indicator_tokens=flatten_bn(indicator_tokens, concat=True),
+                 grids=flatten_bn(grids, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _parse_and_validate_video_input(
        self, **kwargs: object
-     ) -> Optional[OvisImagePatchInputs]:
+     ) -> Optional[Ovis2_5VideoPatchInputs]:
        pixel_values = kwargs.pop("video_pixel_values", None)
        indicator_tokens = kwargs.pop("video_indicator_tokens", None)
        grids = kwargs.pop("video_grids", None)
@@ -523,26 +535,26 @@ def _parse_and_validate_video_input(
                f"Got type: {type(indicator_tokens)}"
            )

-             return OvisVideoPatchInputs(
+             return Ovis2_5VideoPatchInputs(
                type="video_patches",
-                 flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
-                 patches_per_image=[
+                 flat_data=flatten_bn(pixel_values, concat=True),
+                 patches_per_item=[
                    x.shape[0] // (self.config.vit_config.hidden_stride**2)
-                     for x in flatten_bn(pixel_values)
+                     for x in pixel_values
                ],
-                 indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
-                 grids=flatten_bn(flatten_bn(grids), concat=True),
+                 indicator_tokens=flatten_bn(indicator_tokens, concat=True),
+                 grids=flatten_bn(grids, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

-     def _process_image_input(
-         self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
+     def _process_visual_input(
+         self, visual_input: Union[Ovis2_5ImagePatchInputs, Ovis2_5VideoPatchInputs]
    ) -> MultiModalEmbeddings:
-         image_patches_flat = image_input["flat_data"]
-         patches_per_image = image_input["patches_per_image"]
-         indicator_tokens = image_input["indicator_tokens"]
-         grid_thws = image_input["grids"]
+         image_patches_flat = visual_input["flat_data"]
+         patches_per_image = visual_input["patches_per_item"]
+         indicator_tokens = visual_input["indicator_tokens"]
+         grid_thws = visual_input["grids"]

        indicator_per_image = list(
            map(lambda x: 2 if x > 1 else x + 2, patches_per_image)
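The lambda above maps each item's patch count to its indicator count: any multi-patch item gets 2 indicator tokens, while a single-patch item gets `1 + 2 = 3`. A quick worked example with hypothetical counts:

```python
patches_per_image = [3, 1, 8]  # hypothetical patch counts per item
indicator_per_image = list(map(lambda x: 2 if x > 1 else x + 2, patches_per_image))
# -> [2, 3, 2]
```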
@@ -604,11 +616,11 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
-                 vision_embeddings = self._process_image_input(image_input)
+                 vision_embeddings = self._process_visual_input(image_input)
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
-                 video_embeddings = self._process_image_input(video_input)
+                 video_embeddings = self._process_visual_input(video_input)
                multimodal_embeddings += video_embeddings

        return multimodal_embeddings