Skip to content

Commit a01e001

Browse files
[Bugfix] Fix Nemotron VL image processing (#22739)
Co-authored-by: ducviet00-h2 <[email protected]>
1 parent 9e7e5ba commit a01e001

File tree

2 files changed

+190
-4
lines changed

2 files changed

+190
-4
lines changed

tests/models/multimodal/processing/test_nemotron_vl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def _get_expected_num_patches(
2323
min_num: int,
2424
max_num: int,
2525
):
26-
from vllm.model_executor.models.internvl import (
27-
calculate_internvl_targets, get_internvl_target_ratios)
26+
from vllm.model_executor.models.nemotron_vl import (
27+
calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios)
2828

2929
width, height = image.size
3030

31-
blocks, _, _ = calculate_internvl_targets(
31+
blocks, _, _ = calculate_nemotron_vl_targets(
3232
orig_width=width,
3333
orig_height=height,
34-
target_ratios=get_internvl_target_ratios(
34+
target_ratios=get_nemotron_vl_target_ratios(
3535
min_num,
3636
max_num,
3737
),

vllm/model_executor/models/nemotron_vl.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import torch
1515
import torch.nn as nn
16+
import torchvision.transforms as T
1617
from PIL import Image
1718
from transformers import AutoModel, PretrainedConfig
1819
from transformers.image_processing_utils_fast import BaseImageProcessorFast
@@ -27,6 +28,7 @@
2728
from vllm.model_executor.models.module_mapping import MultiModelKeys
2829
from vllm.model_executor.sampling_metadata import SamplingMetadata
2930
from vllm.multimodal import MULTIMODAL_REGISTRY
31+
from vllm.multimodal.image import convert_image_mode
3032
from vllm.multimodal.inputs import NestedTensors
3133
from vllm.multimodal.processing import PromptUpdateDetails
3234
from vllm.sequence import IntermediateTensors
@@ -44,6 +46,146 @@
4446
IMG_CONTEXT = '<image>'
4547

4648

49+
def build_transform(input_size: int):
    """Build the per-tile preprocessing pipeline.

    Forces the image to RGB, bicubic-resizes it to a square of
    ``input_size`` x ``input_size``, and converts it to a float tensor.
    """
    steps = [
        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
    ]
    return T.Compose(steps)
56+
57+
58+
# adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Choose the tile grid from ``target_ratios`` that best fits the image.

    Each candidate is scored by (capped) area coverage times aspect-ratio
    similarity; the highest-scoring grid wins. Ties keep the earlier
    candidate (strict ``>`` comparison).
    """
    image_area = width * height
    chosen = (1, 1)
    chosen_score = float('-inf')

    for grid_w, grid_h in target_ratios:
        candidate_ratio = grid_w / grid_h
        # Coverage term is capped at 0.6 so huge grids cannot dominate
        # purely by area.
        coverage = min(
            grid_w * grid_h * image_size * image_size / image_area, 0.6)
        similarity = min(candidate_ratio / aspect_ratio,
                         aspect_ratio / candidate_ratio)
        score = coverage * similarity

        if score > chosen_score:
            chosen_score = score
            chosen = (grid_w, grid_h)

    return chosen
83+
84+
85+
def calculate_nemotron_vl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Resolve the tiling layout for an image.

    Returns ``(blocks, target_width, target_height)`` where ``blocks`` is
    the number of tiles (including the optional thumbnail tile) and the
    target dimensions are the size the image is resized to before being
    cut into ``image_size`` x ``image_size`` tiles.
    """
    ratio = orig_width / orig_height

    # Pick the candidate grid whose aspect ratio best matches the image.
    grid_w, grid_h = find_closest_aspect_ratio(
        ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    target_width = image_size * grid_w
    target_height = image_size * grid_h
    blocks = grid_w * grid_h

    # A thumbnail tile is only added for multi-tile layouts.
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
114+
115+
116+
def dynamic_preprocess_nemotron_vl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Split ``image`` into a grid of square tiles.

    The image is resized to the best-fitting grid dimensions and cut
    row-major into ``image_size`` x ``image_size`` tiles; for multi-tile
    layouts a full-image thumbnail tile is optionally appended.
    """
    orig_width, orig_height = image.size

    # Layout is computed without the thumbnail; it is appended below.
    blocks, target_width, target_height = calculate_nemotron_vl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized = image.resize((target_width, target_height))
    cols = target_width // image_size

    tiles = []
    for idx in range(blocks):
        row, col = divmod(idx, cols)
        left = col * image_size
        top = row * image_size
        # Crop tile ``idx`` (row-major order) out of the resized image.
        tiles.append(
            resized.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles
153+
154+
155+
def get_nemotron_vl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate all (cols, rows) grids with a tile count in
    ``[min_num, max_num]``, ordered by ascending total tile count.

    NOTE: the comprehension structure is kept as upstream so that ties in
    the sort key preserve the same set-iteration order.
    """
    ratios = {(w, h)
              for total in range(min_num, max_num + 1)
              for w in range(1, total + 1)
              for h in range(1, total + 1) if min_num <= w * h <= max_num}
    return sorted(ratios, key=lambda wh: wh[0] * wh[1])
164+
165+
166+
def image_to_pixel_values_nemotron_vl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Dynamically tile ``image`` and return the stacked pixel values.

    The result is one tensor with the per-tile transformed images stacked
    along a new leading (tile) dimension.
    """
    transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_nemotron_vl(
        image,
        target_ratios=get_nemotron_vl_target_ratios(min_num, max_num),
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])
187+
188+
47189
class NemotronVLProcessor(InternVLProcessor):
48190

49191
def __init__(
@@ -87,6 +229,50 @@ def __init__(
87229
def image_token_id(self) -> int:
88230
return self.tokenizer.get_vocab()[IMG_CONTEXT]
89231

232+
def get_num_image_tokens(
233+
self,
234+
*,
235+
image_width: int,
236+
image_height: int,
237+
) -> int:
238+
target_ratios = self.resolve_target_ratios(
239+
use_thumbnail=False, # Applied in calculate_targets
240+
)
241+
242+
num_patches, _, _ = calculate_nemotron_vl_targets(
243+
orig_width=image_width,
244+
orig_height=image_height,
245+
image_size=self.image_size,
246+
target_ratios=target_ratios,
247+
use_thumbnail=self.use_thumbnail,
248+
)
249+
250+
return num_patches * self.num_image_token
251+
252+
def _images_to_pixel_values_lst(
253+
self,
254+
images: list[Image.Image],
255+
min_dynamic_patch: Optional[int] = None,
256+
max_dynamic_patch: Optional[int] = None,
257+
dynamic_image_size: Optional[bool] = None,
258+
) -> list[torch.Tensor]:
259+
min_num, max_num = self.resolve_min_max_num(
260+
min_dynamic_patch=min_dynamic_patch,
261+
max_dynamic_patch=max_dynamic_patch,
262+
dynamic_image_size=dynamic_image_size,
263+
use_thumbnail=False, # Applied in image_to_pixel_values
264+
)
265+
266+
return [
267+
image_to_pixel_values_nemotron_vl(
268+
image,
269+
input_size=self.image_size,
270+
min_num=min_num,
271+
max_num=max_num,
272+
use_thumbnail=self.use_thumbnail,
273+
) for image in images
274+
]
275+
90276
def _preprocess_image(
91277
self,
92278
text: list[str],

0 commit comments

Comments
 (0)