    IsHybrid,
    MultiModalEmbeddings,
    SupportsMultiModal,
+    SupportsMultiModalPruning,
)
from vllm.model_executor.models.internvl import (
    calculate_internvl_targets,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.evs import (
+    compute_retained_tokens_count,
+    compute_retention_mask,
+)
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
+    _seq2tokens,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.radio import RadioConfig
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import (
+    AnyTokenizer,
+    cached_tokenizer_from_config,
+    encode_tokens,
+)
from vllm.utils.tensor_schema import TensorSchema, TensorShape

+from .utils import _merge_multimodal_embeddings
+
# Configure PIL to handle large images without warnings
# This prevents DecompressionBombWarning for legitimate large images
Image.MAX_IMAGE_PIXELS = None  # Disable the limit entirely
@@ -382,6 +394,7 @@ def __init__(
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        video_token: Optional[str] = None,
+        video_pruning_rate: Optional[float] = None,
    ) -> None:
        super().__init__(
            config=config,
@@ -392,6 +405,7 @@ def __init__(
        )
        # add extra video token for video processing
        self.video_token = video_token
+        self.video_pruning_rate = video_pruning_rate

    @property
    def supports_video(self) -> bool:
@@ -446,12 +460,38 @@ def _preprocess_video(
            ),
        }

+        image_size: int = self.config.force_image_size
+        patch_size: int = self.config.patch_size
+        downsample_ratio = self.config.downsample_ratio
+        tokens_in_single_frame = int(
+            (image_size * image_size // patch_size**2) * (downsample_ratio**2)
+        )
+
        for pixel_values in pixel_values_lst_video:
-            num_patches = pixel_values.shape[0]
+            num_frames = pixel_values.shape[0]
+
+            if (
+                self.video_pruning_rate is not None
+                and self.video_pruning_rate > 0.0
+            ):
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=tokens_in_single_frame,
+                    num_frames=num_frames,
+                    q=self.video_pruning_rate,
+                )
+
+                # Here we just need placeholders that won't actually be replaced -
+                # we just need to make sure the total number of tokens is correct,
+                # so assign all tokens to the first frame
+                tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [tokens_in_single_frame] * num_frames
+
+            video_repl = self.get_video_repl(tokens_per_frame, self.video_token)

-            video_repl = self.get_video_repl(
-                self.num_image_token, num_patches, self.video_token
-            )
            text = [t.replace("<video>", video_repl.full, 1) for t in text]
        return text, video_inputs
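For intuition, the per-frame token budget above works out as in the following standalone sketch. The config values (image_size=512, patch_size=16, downsample_ratio=0.5) are illustrative assumptions only, and retained_total is a rough stand-in for compute_retained_tokens_count rather than its actual implementation.

    # Illustrative values, not read from any real model config.
    image_size, patch_size, downsample_ratio = 512, 16, 0.5
    tokens_in_single_frame = int(
        (image_size * image_size // patch_size**2) * (downsample_ratio**2)
    )  # 1024 patches per frame, reduced to 256 tokens by pixel shuffle

    num_frames, q = 8, 0.75  # q is the video pruning rate
    # Rough stand-in for compute_retained_tokens_count: keep ~(1 - q) of the total.
    retained_total = int(tokens_in_single_frame * num_frames * (1 - q))  # 512

    # Placeholder distribution used while building the prompt: only the total
    # matters here, so all retained tokens are assigned to the first frame.
    tokens_per_frame = [retained_total] + [0] * (num_frames - 1)
    assert sum(tokens_per_frame) == retained_total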
@@ -501,20 +541,40 @@ def get_image_repl(

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)

+    @classmethod
    def get_video_repl(
-        self,
-        feature_size: int,
-        num_patches: Optional[int] = None,
+        cls,
+        tokens_per_frame: list[int],
        video_context_token: str = IMG_CONTEXT,
    ) -> PromptUpdateDetails[str]:
-        repl_features = video_context_token * self.num_image_token
-        repl_features_with_sep = IMG_START + repl_features + IMG_END
-        # num_patches is equal to num_frames
+        """
+        Build the prompt replacement for a video.
+
+        The replacement returned here is not actually used to replace the
+        placeholder tokens - it only ensures that the correct number of tokens
+        is allocated. The actual replacement is done in
+        get_multimodal_embeddings of NemotronH_Nano_VL_V2 (specifically in
+        _process_video_input -> _create_final_video_embeddings), where the
+        final embeddings combine text embeddings for the indicator tokens with
+        video embeddings for the video tokens.
+
+        This single function handles all cases - non-EVS, EVS dummy, and EVS
+        real - differentiated via the tokens_per_frame parameter:
+        - non-EVS: the same constant value for every frame
+        - EVS dummy: only the total number of tokens matters; how they are
+          distributed between frames is irrelevant
+        - EVS real (called from get_real_video_repl_for_evs): a different
+          value per frame
+
+        Args:
+            tokens_per_frame (list[int]): number of tokens per frame
+            video_context_token (str): the token to use for the video context
+        """
        repl_full = "".join(
-            [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
+            [
+                f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
+                for i, num_tokens in enumerate(tokens_per_frame)
+            ]
        )

-        return PromptUpdateDetails.select_text(repl_full, video_context_token)
+        return PromptUpdateDetails.from_seq(repl_full)

class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
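As a quick illustration of the three tokens_per_frame cases that get_video_repl above distinguishes, the sketch below rebuilds the same replacement string outside vLLM. The literal values chosen for IMG_START, IMG_END and IMG_CONTEXT are assumptions for the example, not the model's actual special tokens.

    IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<image>"  # assumed literals

    def video_repl_text(tokens_per_frame: list[int]) -> str:
        return "".join(
            f"Frame{i + 1}: {IMG_START}{IMG_CONTEXT * n}{IMG_END}"
            for i, n in enumerate(tokens_per_frame)
        )

    print(video_repl_text([2, 2]))  # non-EVS: every frame carries 2 context tokens
    print(video_repl_text([3, 0]))  # EVS dummy: 3 tokens total, all placed on frame 1
    print(video_repl_text([2, 1]))  # EVS real: per-frame counts from the retention mask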
@@ -605,6 +665,9 @@ def get_supported_mm_limits(self):
    def get_video_token(self) -> Optional[str]:
        return IMG_CONTEXT

+    def get_video_pruning_rate(self) -> Optional[float]:
+        return self.ctx.get_mm_config().video_pruning_rate
+
    def get_num_frames_with_most_features(
        self,
        seq_len: int,
@@ -628,6 +691,7 @@ def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            video_token=self.get_video_token(),
+            video_pruning_rate=self.get_video_pruning_rate(),
            **kwargs,
        )
@@ -805,8 +869,26 @@ def get_video_replacement_internvl(item_idx: int):
            if num_patches is not None:
                assert isinstance(num_patches, int)

+            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=feature_size,
+                    num_frames=num_patches,
+                    q=video_pruning_rate,
+                )
+                # Here we just need placeholders that won't actually be replaced -
+                # we just need to make sure the total number of tokens is correct,
+                # so assign all tokens to the first frame
+                tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [feature_size] * num_patches
+
            return hf_processor.get_video_repl(
-                feature_size, num_patches, video_context_token=hf_processor.video_token
+                tokens_per_frame,
+                video_context_token=hf_processor.video_token,
            )

        if self.info.supports_video:
@@ -901,7 +983,9 @@ def get_dummy_mm_data(
    info=NanoNemotronVLProcessingInfo,
    dummy_inputs=NanoNemotronVLDummyInputsBuilder,
)
-class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal):
+class NemotronH_Nano_VL_V2(
+    nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
+):
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
@@ -913,7 +997,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
-
+        multimodal_config = vllm_config.model_config.multimodal_config
        image_size = config.force_image_size
        patch_size = config.patch_size
        self.patch_size = patch_size
@@ -924,7 +1008,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version
        self.image_tag_type = config.image_tag_type
-
+        self.video_pruning_rate = multimodal_config.video_pruning_rate
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
@@ -957,6 +1041,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.img_context_token_id = None
        self.video_context_token_id = None
        self.config = config
+        self.model_config = vllm_config.model_config

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
@@ -1049,7 +1134,7 @@ def _parse_and_validate_image_input(

    def _process_image_input(
        self, image_input: NanoNemotronVLImageInputs
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
        if image_input["type"] == "image_embeds":
            return image_input["data"]
@@ -1071,6 +1156,109 @@ def _process_image_input(
        ]
        return image_embeds.split(image_feature_sizes)

+    def _process_video_input(
+        self, video_input: NanoNemotronVLVideoPixelInputs
+    ) -> tuple[torch.Tensor, ...]:
+        """Process video input and create final embeddings with video content
+        and indicator tokens."""
+        # Get video embeddings using the same processing as images
+        video_embeddings = self._process_image_input(video_input)
+
+        final_video_embeddings: tuple[torch.Tensor, ...] = ()
+
+        image_rows = image_cols = self.config.force_image_size
+        downsample_ratio = self.config.downsample_ratio
+        patch_size = self.config.patch_size
+        rows = int(image_rows * downsample_ratio // patch_size)
+        cols = int(image_cols * downsample_ratio // patch_size)
+        video_pruning_rate = self.video_pruning_rate
+
+        # Calculate video feature dimensions (number of frames and
+        # their feature size (AKA tokens per frame))
+        # TODO: Maybe this can be optimized to avoid the loop?
+        for i, single_video_embeddings in enumerate(video_embeddings):
+            num_frames = video_input["num_patches"][i].item()
+            assert single_video_embeddings.shape[0] % num_frames == 0
+
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                retention_mask = compute_retention_mask(
+                    single_video_embeddings,
+                    video_size_thw=(num_frames, rows, cols),
+                    spatial_merge_size=1,
+                    q=video_pruning_rate,
+                )
+
+                # apply retention mask
+                single_video_embeddings = single_video_embeddings[retention_mask]
+
+                # calculate the actual number of retained tokens per frame
+                retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
+                num_tokens_per_frame = (
+                    retention_mask_thw.sum(dim=(1, 2)).long().tolist()
+                )
+                # End of EVS-specific code
+            else:
+                feature_size = single_video_embeddings.shape[0] // num_frames
+                num_tokens_per_frame = [feature_size] * num_frames
+
+            final_video_embeddings += (
+                self._create_final_video_embeddings(
+                    single_video_embeddings,
+                    num_tokens_per_frame,
+                ),
+            )
+
+        return final_video_embeddings
+
+    def _create_final_video_embeddings(
+        self,
+        video_embeddings: torch.Tensor,
+        num_tokens_per_frame: list[int],
+    ) -> torch.Tensor:
+        """Create final embeddings that combine video embeddings with
+        text embeddings of indicator tokens.
+
+        These final embeddings contain:
+        - Actual video embeddings in positions corresponding to video content
+        - Text embeddings for indicator tokens (<img>, </img>, and
+          frame separation text) in their respective positions
+
+        These embeddings will replace the placeholder embeddings to create
+        input_embeds for the LLM.
+        """
+        device = video_embeddings.device

+        # Generate video replacement text and convert to token IDs
+        video_repl_text = NanoNemotronVLProcessor.get_video_repl(
+            num_tokens_per_frame,
+            IMG_CONTEXT,
+        ).full
+
+        tokenizer = cached_tokenizer_from_config(self.model_config)
+        repl_token_ids = torch.tensor(
+            _seq2tokens(tokenizer, video_repl_text), device=device
+        )
+
+        # Get embedding token IDs for image context
+        embed_token_ids = torch.tensor(
+            encode_tokens(tokenizer, IMG_CONTEXT), device=device
+        )
+
+        # Create mask for video embedding positions
+        is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
+
+        # Create final video embeddings, merging text embeddings for indicator
+        # tokens with video embeddings
+        text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids)
+        final_video_embeddings = _merge_multimodal_embeddings(
+            inputs_embeds=text_embeddings,
+            multimodal_embeddings=video_embeddings,
+            is_multimodal=is_video_embed,
+        )
+
+        return final_video_embeddings
+

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Optional[NanoNemotronVLVideoPixelInputs]:
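The EVS branch in _process_video_input above reduces to boolean-mask indexing over a (frames, rows, cols) token grid. A minimal standalone sketch, with a random mask standing in for vllm.multimodal.evs.compute_retention_mask (which, roughly speaking, prefers to drop tokens that change little between consecutive frames):

    import torch

    num_frames, rows, cols, hidden = 2, 3, 3, 8
    video_embeddings = torch.randn(num_frames * rows * cols, hidden)

    # Stand-in mask; the real one comes from compute_retention_mask.
    retention_mask = torch.rand(num_frames * rows * cols) > 0.5

    pruned = video_embeddings[retention_mask]  # retained tokens, original order kept
    num_tokens_per_frame = (
        retention_mask.reshape(num_frames, rows, cols).sum(dim=(1, 2)).long().tolist()
    )
    assert pruned.shape[0] == sum(num_tokens_per_frame)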
@@ -1152,7 +1340,7 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
                multimodal_embeddings += vision_embeddings
            if modality == "videos":
                video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_video_input(video_input)
                multimodal_embeddings += video_embeddings

        return multimodal_embeddings
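To make the routing change concrete: each tensor returned by _process_video_input is already a fully merged sequence, assembled the way the simplified sketch below does it. The real merge is performed by _merge_multimodal_embeddings from vllm.model_executor.models.utils; the inline version here only mirrors its basic scatter behavior.

    import torch

    # Text embeddings of the tokenized replacement string, e.g.
    # "Frame1: <img><ctx><ctx></img>" -> 6 tokens with hidden size 4.
    text_embeddings = torch.zeros(6, 4)
    video_embeddings = torch.ones(2, 4)  # the two retained video tokens
    is_video_embed = torch.tensor([False, False, True, True, False, False])

    merged = text_embeddings.clone()
    merged[is_video_embed] = video_embeddings  # video tokens fill the <ctx> positions
    # merged now interleaves indicator-token text embeddings with video embeddings.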