
Commit 2f689df

DarkLight1337 and amd-xiaoyu12 authored and committed

[Bugfix] Ensure correctness of HCXVision processing (vllm-project#23254)

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>

1 parent 41ac9b2 · commit 2f689df

File tree

2 files changed, +56 -64 lines changed

tests/models/multimodal/processing/test_common.py
Lines changed: 1 addition & 1 deletion

@@ -102,7 +102,7 @@ def _test_processing_correctness(
         partial(random_video,
                 rng,
                 min_frames=2,
-                max_frames=8,
+                max_frames=16,
                 min_wh=128,
                 max_wh=256),
     "audio":

vllm/model_executor/models/hyperclovax_vision.py
Lines changed: 55 additions & 63 deletions
@@ -53,6 +53,21 @@
 VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"


+# Based on combine_frames_into_images in
+# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
+def get_num_combined_frames(
+    num_frames: int,
+    max_grid_shape: tuple[int, int] = (3, 3),
+) -> int:
+    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
+
+    # Calculate the number of canvases needed.
+    num_canvases = num_frames // max_num_grids
+    leftover_frames = num_frames % max_num_grids
+
+    return num_canvases + (leftover_frames > 0)
+
+
 class HCXVisionMultimodalPixelInputs(TypedDict):
     type: Literal["pixel_values"]
     pixel_values_images: list[torch.Tensor]
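For intuition, here is a standalone copy of the helper above with a few spot checks (the checks are illustrative, not part of the commit): each canvas packs at most max_grid_shape[0] * max_grid_shape[1] frames, so the canvas count is a ceiling division.

def get_num_combined_frames(
    num_frames: int,
    max_grid_shape: tuple[int, int] = (3, 3),
) -> int:
    # Each canvas packs up to 3 * 3 = 9 frames by default.
    max_num_grids = max_grid_shape[0] * max_grid_shape[1]
    num_canvases = num_frames // max_num_grids
    leftover_frames = num_frames % max_num_grids
    # A partial canvas still counts as one (the bool adds as 0 or 1).
    return num_canvases + (leftover_frames > 0)

for n in (2, 8, 9, 10, 16, 18):
    print(n, "->", get_num_combined_frames(n))
# 2 -> 1, 8 -> 1, 9 -> 1, 10 -> 2, 16 -> 2, 18 -> 2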
@@ -172,23 +187,20 @@ def _call_hf_processor(
         def replace_multimodal_token(
             token_ids: torch.Tensor,
             target_token: int,
-            repeats: list,
+            repeats: list[int],
         ):
-            output = list()
+            output = list[int]()
             _repeats_idx = 0
             for token_id in token_ids:
                 if token_id == target_token:
-                    output += [
-                        token_id.item(),
-                    ] * repeats[_repeats_idx]
+                    output += [token_id.item()] * repeats[_repeats_idx]
                     _repeats_idx += 1
                 else:
-                    output += [
-                        token_id.item(),
-                    ]
+                    output += [token_id.item()]
+
             return torch.tensor(output, device=token_ids.device)

-        for video_idx, video_arr in enumerate(mm_data.get("videos", list())):
+        for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
             if video_arr.dtype == np.uint8:
                 continue
             mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
@@ -205,88 +217,68 @@ def replace_multimodal_token(
         if len(mm_data) > 0:
             # batchify input as a single item
             images = mm_data.get("images", None)
-            num_images = 0
-            if images is not None:
-                num_images = len(images)
-                images = [
-                    images,
-                ]  # batchify
-
-            videos = mm_data.get("videos",
-                                 None)  # list of video in single conversation
-            num_videos = 0
-            if videos is not None:
-                num_videos = len(videos)
-                videos = [
-                    videos,
-                ]  # batchify
+            batched_images = None if images is None else [images]
+
+            # list of video in single conversation
+            videos = mm_data.get("videos", None)
+            batched_videos = None if videos is None else [videos]

             _processed_outputs = self.info.ctx.call_hf_processor(
                 hf_processor=self.info.get_hf_processor(**mm_kwargs),
                 data=dict(
                     text=None,
-                    images=images,
-                    videos=videos,
+                    images=batched_images,
+                    videos=batched_videos,
                 ),
             )  # mm-only

             for k, v in _processed_outputs.items():
-                if len(v) < 1:
-                    continue
-                elif k.endswith("_images"):
-                    # list of list of 4D tensor -> list of 4D tensor
+                if isinstance(v, list) and len(v) > 0:
+                    assert len(v) == 1
                     _processed_outputs[k] = v[0]
-                elif k.endswith("_videos"):
-                    # list of list of 4D tensor -> list of 4D tensor
-                    v = v[0]
-                    if k == "pixel_values_videos":
-                        v = torch.cat(v, dim=0)
-                        _c, _w, _h = v.shape[-3:]
-                        v = v.reshape(num_videos, -1, _c, _w, _h)
-                        v = list(torch.unbind(v, dim=0))
-                    _processed_outputs[k] = v
-
-            if num_images > 0:
+
+            if images:
                 tokenizer = self.info.get_tokenizer()
+                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
                 processed_outputs["input_ids"] = torch.stack([
                     replace_multimodal_token(
                         token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            IMAGE_TOKEN),
+                        target_token=image_token_id,
                         repeats=_processed_outputs[
                             "vision_query_lengths_images"],
                     ) for _input_ids in processed_outputs["input_ids"]
                 ],
                                                              dim=0)

-            if num_videos > 0:
+            if videos:
+                _num_per_videos = [
+                    get_num_combined_frames(len(video)) for video in videos
+                ]
+                _processed_outputs["pixel_values_videos"] = [
+                    _processed_outputs["pixel_values_videos"]
+                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
+                    for _i in range(len(videos))
+                ]
+                _processed_outputs["vision_query_lengths_videos"] = [
+                    _processed_outputs["vision_query_lengths_videos"]
+                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
+                    for _i in range(len(videos))
+                ]
+
                 tokenizer = self.info.get_tokenizer()
+                video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
                 processed_outputs["input_ids"] = torch.stack([
                     replace_multimodal_token(
                         token_ids=_input_ids,
-                        target_token=tokenizer.convert_tokens_to_ids(
-                            VIDEO_TOKEN),
-                        repeats=_processed_outputs[
-                            "vision_query_lengths_videos"],
+                        target_token=video_token_id,
+                        repeats=[
+                            sum(lens) for lens in
+                            _processed_outputs["vision_query_lengths_videos"]
+                        ],
                     ) for _input_ids in processed_outputs["input_ids"]
                 ],
                                                              dim=0)

-                _ratios = [
-                    len(_pixel_values) for _pixel_values in
-                    _processed_outputs["pixel_values_videos"]
-                ]
-                _num_per_videos = [
-                    int(_e / sum(_ratios) *
-                        len(_processed_outputs["vision_query_lengths_videos"]))
-                    for _e in _ratios
-                ]
-                _processed_outputs["vision_query_lengths_videos"] = [
-                    _processed_outputs["vision_query_lengths_videos"]
-                    [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
-                    for _i in range(0, num_videos)
-                ]
-
             processed_outputs.update(_processed_outputs)

         return processed_outputs
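The substantive fix is in the video branch. The removed code estimated the per-video split of vision_query_lengths_videos proportionally from _ratios over pixel_values_videos, and reshaped the HF-processor output with reshape(num_videos, -1, ...), which assumes every video yields the same number of canvases; both approximations can misassign canvases when clips have unequal frame counts. The new code computes the exact canvas count per video with get_num_combined_frames, slices pixel_values_videos and vision_query_lengths_videos at the same offsets, and expands each video placeholder by the sum of its per-canvas query lengths. A toy illustration with a hypothetical two-video batch (16 and 5 frames; the query length 81 is made up):

# get_num_combined_frames(16) == 2 and get_num_combined_frames(5) == 1,
# so the processor's flat per-canvas list has 3 entries to regroup.
_num_per_videos = [2, 1]
flat = ["v0_c0", "v0_c1", "v1_c0"]  # stand-ins for per-canvas tensors
per_video = [
    flat[sum(_num_per_videos[:i]):sum(_num_per_videos[:i + 1])]
    for i in range(len(_num_per_videos))
]
print(per_video)  # [['v0_c0', 'v0_c1'], ['v1_c0']]

# Each video's placeholder token is then repeated by its total query length:
vision_query_lengths_videos = [[81, 81], [81]]
repeats = [sum(lens) for lens in vision_query_lengths_videos]
print(repeats)  # [162, 81]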
