
Commit b8f603c

Authored by tomeras91, gemini-code-assist[bot], and BloodAxe

[Model] EVS support for nano_nemotron_vl (#26269)

Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: tomeras91 <[email protected]>
Signed-off-by: Eugene Khvedchenia <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Eugene Khvedchenia <[email protected]>
1 parent fc67969 commit b8f603c

File tree

3 files changed: +224 -31 lines changed

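For orientation: the diffs below repeatedly call compute_retained_tokens_count from vllm.multimodal.evs to decide how many placeholder tokens to reserve for a pruned video. As a rough mental model only (this is not the vLLM implementation; the rounding and the one-full-frame floor are assumptions), the bookkeeping looks like:

# Illustrative sketch only - not vllm.multimodal.evs.
# Assumption: EVS retains roughly a (1 - q) fraction of all video tokens,
# with a floor of one full frame; exact rounding may differ.
def approx_retained_tokens_count(tokens_per_frame: int, num_frames: int, q: float) -> int:
    total = tokens_per_frame * num_frames
    return max(tokens_per_frame, int(total * (1.0 - q)))

# e.g. 256 tokens per frame, 8 frames, pruning rate q=0.75 -> about 512 tokens
print(approx_retained_tokens_count(256, 8, 0.75))  # 512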

vllm/model_executor/models/nano_nemotron_vl.py

Lines changed: 207 additions & 19 deletions
@@ -30,6 +30,7 @@
     IsHybrid,
     MultiModalEmbeddings,
     SupportsMultiModal,
+    SupportsMultiModalPruning,
 )
 from vllm.model_executor.models.internvl import (
     calculate_internvl_targets,
@@ -44,6 +45,10 @@
     maybe_prefix,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.evs import (
+    compute_retained_tokens_count,
+    compute_retention_mask,
+)
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
@@ -62,13 +67,20 @@
     PromptReplacement,
     PromptUpdate,
     PromptUpdateDetails,
+    _seq2tokens,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.radio import RadioConfig
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import (
+    AnyTokenizer,
+    cached_tokenizer_from_config,
+    encode_tokens,
+)
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
+from .utils import _merge_multimodal_embeddings
+
 # Configure PIL to handle large images without warnings
 # This prevents DecompressionBombWarning for legitimate large images
 Image.MAX_IMAGE_PIXELS = None  # Disable the limit entirely
@@ -382,6 +394,7 @@ def __init__(
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         video_token: Optional[str] = None,
+        video_pruning_rate: Optional[float] = None,
     ) -> None:
         super().__init__(
             config=config,
@@ -392,6 +405,7 @@ def __init__(
         )
         # add extra video token for video processing
         self.video_token = video_token
+        self.video_pruning_rate = video_pruning_rate
 
     @property
     def supports_video(self) -> bool:
@@ -446,12 +460,38 @@ def _preprocess_video(
             ),
         }
 
+        image_size: int = self.config.force_image_size
+        patch_size: int = self.config.patch_size
+        downsample_ratio = self.config.downsample_ratio
+        tokens_in_single_frame = int(
+            (image_size * image_size // patch_size**2) * (downsample_ratio**2)
+        )
+
         for pixel_values in pixel_values_lst_video:
-            num_patches = pixel_values.shape[0]
+            num_frames = pixel_values.shape[0]
+
+            if (
+                self.video_pruning_rate is not None
+                and self.video_pruning_rate > 0.0
+            ):
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=tokens_in_single_frame,
+                    num_frames=num_frames,
+                    q=self.video_pruning_rate,
+                )
+
+                # Here we just need placeholders that won't actually be replaced -
+                # we just need to make sure the total number of tokens is correct
+                # assign all tokens to the first frame
+                tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [tokens_in_single_frame] * num_frames
+
+            video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
 
-            video_repl = self.get_video_repl(
-                self.num_image_token, num_patches, self.video_token
-            )
             text = [t.replace("<video>", video_repl.full, 1) for t in text]
         return text, video_inputs
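Note on the hunk above: at preprocessing time the retention mask is not yet known (it depends on the vision embeddings), so the processor only needs to reserve the correct total and simply assigns the whole retained budget to the first frame. A minimal sketch of that placeholder allocation, where retained stands in for the value returned by compute_retained_tokens_count:

# Placeholder sizing only - the real per-frame split is computed later in the model.
num_frames = 8
retained = 512  # stand-in for compute_retained_tokens_count(...)
tokens_per_frame = [retained] + [0] * (num_frames - 1)
assert sum(tokens_per_frame) == retained  # only the total matters here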

@@ -501,20 +541,40 @@ def get_image_repl(
 
         return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
 
+    @classmethod
     def get_video_repl(
-        self,
-        feature_size: int,
-        num_patches: Optional[int] = None,
+        cls,
+        tokens_per_frame: list[int],
         video_context_token: str = IMG_CONTEXT,
     ) -> PromptUpdateDetails[str]:
-        repl_features = video_context_token * self.num_image_token
-        repl_features_with_sep = IMG_START + repl_features + IMG_END
-        # num_patches is equal to num_frames
+        """
+        Build prompt replacement for a video.
+        The replacement returned is not actually used to replace the placeholder
+        tokens - it's just used to make sure we allocate the correct number
+        of tokens.
+        Actual replacement is done in get_multimodal_embeddings of
+        NemotronH_Nano_VL_V2
+        (specifically in _process_video_input -> _create_final_video_embeddings).
+        There, we create the final embeddings with text embeddings for indicator tokens
+        and video embeddings for video tokens.
+        This is a single function that handles all cases - non EVS, EVS dummy, EVS real.
+        The differentiation is done via tokens_per_frame parameter.
+        - non EVS case - constant value same value across all frames
+        - EVS dummy - Doesn't matter how tokens are distributed between frames - just
+          make sure the total number of tokens is correct.
+        - EVS real (called from get_real_video_repl_for_evs) - different value per frame
+        Args:
+            tokens_per_frame (list[int]): number of tokens per frame
+            video_context_token (str): the token to use for the video context
+        """
         repl_full = "".join(
-            [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
+            [
+                f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
+                for i, num_tokens in enumerate(tokens_per_frame)
+            ]
         )
 
-        return PromptUpdateDetails.select_text(repl_full, video_context_token)
+        return PromptUpdateDetails.from_seq(repl_full)
 
 
 class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
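For illustration, the reworked get_video_repl concatenates one "FrameN: <start><context tokens><end>" chunk per frame, sized by tokens_per_frame. A small sketch with made-up token literals (the real IMG_START, IMG_END, and IMG_CONTEXT values are defined in the model file):

# Made-up token strings; the actual IMG_START/IMG_END/IMG_CONTEXT differ.
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<ctx>"

def sketch_video_repl(tokens_per_frame: list[int]) -> str:
    return "".join(
        f"Frame{i + 1}: {IMG_START}{IMG_CONTEXT * n}{IMG_END}"
        for i, n in enumerate(tokens_per_frame)
    )

print(sketch_video_repl([2, 1]))
# Frame1: <img><ctx><ctx></img>Frame2: <img><ctx></img>

In the non-EVS path every entry of tokens_per_frame is the same constant; in the EVS dummy path it is [total, 0, 0, ...]; in the real EVS path each entry is that frame's count of retained tokens.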
@@ -605,6 +665,9 @@ def get_supported_mm_limits(self):
     def get_video_token(self) -> Optional[str]:
         return IMG_CONTEXT
 
+    def get_video_pruning_rate(self) -> Optional[float]:
+        return self.ctx.get_mm_config().video_pruning_rate
+
     def get_num_frames_with_most_features(
         self,
         seq_len: int,
@@ -628,6 +691,7 @@ def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
             config=self.get_hf_config(),
             tokenizer=self.get_tokenizer(),
             video_token=self.get_video_token(),
+            video_pruning_rate=self.get_video_pruning_rate(),
             **kwargs,
         )
@@ -805,8 +869,26 @@ def get_video_replacement_internvl(item_idx: int):
             if num_patches is not None:
                 assert isinstance(num_patches, int)
 
+            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=feature_size,
+                    num_frames=num_patches,
+                    q=video_pruning_rate,
+                )
+                # Here we just need placeholders that won't actually be replaced -
+                # we just need to make sure the total number of tokens is correct
+                # assign all tokens to the first frame
+                tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [feature_size] * num_patches
+
             return hf_processor.get_video_repl(
-                feature_size, num_patches, video_context_token=hf_processor.video_token
+                tokens_per_frame,
+                video_context_token=hf_processor.video_token,
             )
 
         if self.info.supports_video:
@@ -901,7 +983,9 @@ def get_dummy_mm_data(
     info=NanoNemotronVLProcessingInfo,
     dummy_inputs=NanoNemotronVLDummyInputsBuilder,
 )
-class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal):
+class NemotronH_Nano_VL_V2(
+    nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
+):
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -913,7 +997,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-
+        multimodal_config = vllm_config.model_config.multimodal_config
         image_size = config.force_image_size
         patch_size = config.patch_size
         self.patch_size = patch_size
@@ -924,7 +1008,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
         self.image_tag_type = config.image_tag_type
-
+        self.video_pruning_rate = multimodal_config.video_pruning_rate
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -957,6 +1041,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.img_context_token_id = None
         self.video_context_token_id = None
         self.config = config
+        self.model_config = vllm_config.model_config
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -1049,7 +1134,7 @@ def _parse_and_validate_image_input(
 
     def _process_image_input(
         self, image_input: NanoNemotronVLImageInputs
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
@@ -1071,6 +1156,109 @@
             ]
         return image_embeds.split(image_feature_sizes)
 
+    def _process_video_input(
+        self, video_input: NanoNemotronVLVideoPixelInputs
+    ) -> tuple[torch.Tensor, ...]:
+        """Process video input and create final embeddings with video content
+        and indicator tokens."""
+        # Get video embeddings using the same processing as images
+        video_embeddings = self._process_image_input(video_input)
+
+        final_video_embeddings: tuple[torch.Tensor, ...] = ()
+
+        image_rows = image_cols = self.config.force_image_size
+        downsample_ratio = self.config.downsample_ratio
+        patch_size = self.config.patch_size
+        rows = int(image_rows * downsample_ratio // patch_size)
+        cols = int(image_cols * downsample_ratio // patch_size)
+        video_pruning_rate = self.video_pruning_rate
+
+        # Calculate video feature dimensions (number of frames and
+        # their feature size (AKA tokens per frame))
+        # TODO: Maybe this can be optimized to avoid the loop?
+        for i, single_video_embeddings in enumerate(video_embeddings):
+            num_frames = video_input["num_patches"][i].item()
+            assert single_video_embeddings.shape[0] % num_frames == 0
+
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                retention_mask = compute_retention_mask(
+                    single_video_embeddings,
+                    video_size_thw=(num_frames, rows, cols),
+                    spatial_merge_size=1,
+                    q=video_pruning_rate,
+                )
+
+                # apply retention mask
+                single_video_embeddings = single_video_embeddings[retention_mask]
+
+                # calculate the actual number of retained tokens per frame
+                retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
+                num_tokens_per_frame = (
+                    retention_mask_thw.sum(dim=(1, 2)).long().tolist()
+                )
+                # End of EVS-specific code
+            else:
+                feature_size = single_video_embeddings.shape[0] // num_frames
+                num_tokens_per_frame = [feature_size] * num_frames
+
+            final_video_embeddings += (
+                self._create_final_video_embeddings(
+                    single_video_embeddings,
+                    num_tokens_per_frame,
+                ),
+            )
+
+        return final_video_embeddings
+
+    def _create_final_video_embeddings(
+        self,
+        video_embeddings: torch.Tensor,
+        num_tokens_per_frame: list[int],
+    ) -> torch.Tensor:
+        """Create final embeddings that combine video embeddings with
+        text embeddings of indicator tokens.
+
+        These final embeddings contain:
+        - Actual video embeddings in positions corresponding to video content
+        - Text embeddings for indicator tokens (<img>, </img>, and
+          frame separation text) in their respective positions
+
+        These embeddings will replace the placeholder embeddings to create
+        input_embeds for the LLM.
+        """
+        device = video_embeddings.device
+
+        # Generate video replacement text and convert to token IDs
+        video_repl_text = NanoNemotronVLProcessor.get_video_repl(
+            num_tokens_per_frame,
+            IMG_CONTEXT,
+        ).full
+
+        tokenizer = cached_tokenizer_from_config(self.model_config)
+        repl_token_ids = torch.tensor(
+            _seq2tokens(tokenizer, video_repl_text), device=device
+        )
+
+        # Get embedding token IDs for image context
+        embed_token_ids = torch.tensor(
+            encode_tokens(tokenizer, IMG_CONTEXT), device=device
+        )
+
+        # Create mask for video embedding positions
+        is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
+
+        # Create final video embeddings, merging text embeddings for indicator
+        # tokens with video embeddings
+        text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids)
+        final_video_embeddings = _merge_multimodal_embeddings(
+            inputs_embeds=text_embeddings,
+            multimodal_embeddings=video_embeddings,
+            is_multimodal=is_video_embed,
+        )
+
+        return final_video_embeddings
+
     def _parse_and_validate_video_input(
         self, **kwargs: object
     ) -> Optional[NanoNemotronVLVideoPixelInputs]:
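The heart of _create_final_video_embeddings above is a masked merge: tokenize the replacement string, mark which positions are video-context tokens, embed everything as text, then scatter the already-pruned video embeddings into the marked slots. A minimal sketch of that idea, assuming _merge_multimodal_embeddings behaves like the masked assignment below (the real helper does additional validation):

import torch

def sketch_merge(
    repl_token_ids: torch.Tensor,   # (L,) token ids of the replacement string
    embed_token_ids: torch.Tensor,  # ids that mark video-context positions
    text_embeddings: torch.Tensor,  # (L, D) text embeddings of those tokens
    video_embeddings: torch.Tensor, # (N, D) pruned video embeddings, N = mask sum
) -> torch.Tensor:
    is_video = torch.isin(repl_token_ids, embed_token_ids)   # (L,) bool mask
    assert int(is_video.sum()) == video_embeddings.shape[0]  # counts must line up
    merged = text_embeddings.clone()
    merged[is_video] = video_embeddings                       # masked scatter
    return merged

The per-frame counts passed to get_video_repl come from reshaping the retention mask to (num_frames, rows, cols) and summing over the spatial dimensions, so each frame's chunk in the replacement string lines up exactly with the embeddings that survived pruning.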
@@ -1152,7 +1340,7 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_video_input(video_input)
                 multimodal_embeddings += video_embeddings
 
         return multimodal_embeddings

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 6 additions & 2 deletions
@@ -1017,9 +1017,13 @@ def get_replacement_qwen2vl(item_idx: int, modality: str):
                 and video_pruning_rate is not None
                 and video_pruning_rate > 0.0
             ):
+                T, H, W = map(int, grid_thw)
+                tokens_per_frame = (H // image_processor.merge_size) * (
+                    W // image_processor.merge_size
+                )
                 num_tokens = compute_retained_tokens_count(
-                    grid_thw,
-                    image_processor.merge_size,
+                    tokens_per_frame,
+                    T,
                     video_pruning_rate,
                 )
                 # End of EVS-specific code
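
The qwen2_5_vl call site now derives the per-frame token count from the video grid before calling compute_retained_tokens_count. A worked example of that arithmetic with illustrative numbers:

# Illustrative numbers: 16 frames, a 36 x 52 patch grid, 2 x 2 spatial merge.
T, H, W = 16, 36, 52
merge_size = 2
tokens_per_frame = (H // merge_size) * (W // merge_size)  # 18 * 26 = 468
print(T, tokens_per_frame)  # 16 468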
