Skip to content

Commit 18e519e

Browse files
authored
[Bugfix] Fix ndarray video color from VideoAsset (#21064)
Signed-off-by: Isotr0py <[email protected]>
1 parent 1eaff27 commit 18e519e

File tree

3 files changed

+130
-28
lines changed

3 files changed

+130
-28
lines changed

tests/multimodal/test_video.py

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import tempfile
5+
from pathlib import Path
6+
37
import numpy as np
48
import numpy.typing as npt
59
import pytest
10+
from PIL import Image
611

7-
from vllm import envs
12+
from vllm.assets.base import get_vllm_public_assets
13+
from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list
814
from vllm.multimodal.image import ImageMediaIO
915
from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
1016
VideoMediaIO)
1117

18+
from .utils import cosine_similarity, create_video_from_image, normalize_image
19+
1220
NUM_FRAMES = 10
1321
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
1422
FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@@ -59,30 +67,79 @@ def load_bytes(cls,
5967
return FAKE_OUTPUT_2
6068

6169

62-
def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch):
    """Check that kwargs given to ``VideoMediaIO`` reach the video loader.

    The loader backend is selected through the ``VLLM_VIDEO_LOADER_BACKEND``
    environment variable, patched only for the duration of this test.
    """
    # Keyword sets the selected backend is expected to accept.
    accepted_kwargs = [
        {"num_frames": 10, "fps": 1.0},
        {"num_frames": 10, "fps": 1.0, "not_used": "not_used"},
    ]
    # Keyword sets expected to trip an assertion, with the message to match.
    rejected_kwargs = [
        ({}, "bad num_frames"),
        ({"num_frames": 9, "fps": 1.0}, "bad num_frames"),
        ({"num_frames": 10, "fps": 2.0}, "bad fps"),
    ]

    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_VIDEO_LOADER_BACKEND",
                           "assert_10_frames_1_fps")
        image_io = ImageMediaIO()

        # Verify that different args pass/fail assertions as expected.
        for kwargs in accepted_kwargs:
            video_io = VideoMediaIO(image_io, **kwargs)
            _ = video_io.load_bytes(b"test")

        for kwargs, expected_message in rejected_kwargs:
            with pytest.raises(AssertionError, match=expected_message):
                video_io = VideoMediaIO(image_io, **kwargs)
                _ = video_io.load_bytes(b"test")
98+
99+
100+
@pytest.mark.parametrize("is_color", [True, False])
@pytest.mark.parametrize("fourcc, ext", [("mp4v", "mp4"), ("XVID", "avi")])
def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
    """
    Test that all functions using OpenCV for video I/O return RGB format.
    Both RGB and grayscale videos are tested.
    """

    def assert_frames_match(frames, expected: npt.NDArray) -> None:
        # Each decoded frame must be almost identical to the reference image;
        # a tiny fraction of NaN similarities is tolerated.
        for frame in frames:
            sim = cosine_similarity(normalize_image(np.array(frame)),
                                    expected)
            assert np.sum(np.isnan(sim)) / sim.size < 0.001
            assert np.nanmean(sim) > 0.99

    source_path = get_vllm_public_assets(filename="stop_sign.jpg",
                                         s3_prefix="vision_model_images")
    reference = Image.open(source_path)
    with tempfile.TemporaryDirectory() as tmpdir:
        if not is_color:
            source_path = f"{tmpdir}/test_grayscale_image.png"
            reference = reference.convert("L")
            reference.save(source_path)
            # Convert to gray RGB for comparison
            reference = reference.convert("RGB")
        video_path = f"{tmpdir}/test_RGB_video.{ext}"
        create_video_from_image(
            source_path,
            video_path,
            num_frames=2,
            is_color=is_color,
            fourcc=fourcc,
        )

        expected = normalize_image(np.array(reference))

        assert_frames_match(video_to_ndarrays(video_path), expected)
        assert_frames_match(video_to_pil_images_list(video_path), expected)

        io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path))
        assert_frames_match(io_frames, expected)

tests/multimodal/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import cv2
45
import numpy as np
6+
import numpy.typing as npt
57
from PIL import Image
68

79

@@ -31,3 +33,47 @@ def random_audio(
3133
):
3234
audio_len = rng.randint(min_len, max_len)
3335
return rng.rand(audio_len), sr
36+
37+
38+
def create_video_from_image(
    image_path: str,
    video_path: str,
    num_frames: int = 10,
    fps: float = 1.0,
    is_color: bool = True,
    fourcc: str = "mp4v",
):
    """Write a video that repeats a single still image for every frame.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        video_path: Destination path of the video file.
        num_frames: Number of times the image is written as a frame.
        fps: Frame rate passed to the video writer.
        is_color: When False, the image is converted to grayscale and the
            writer is created with ``isColor=False``.
        fourcc: Four-character codec code for the writer.

    Returns:
        ``video_path``, unchanged.

    Raises:
        ValueError: If the image cannot be read.
        RuntimeError: If the video writer cannot be opened.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising,
        # which would otherwise surface as an opaque error further down.
        raise ValueError(f"Could not read image from {image_path!r}")
    if not is_color:
        # Convert to grayscale if is_color is False
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # The first two dimensions are (height, width) for both the 2-D grayscale
    # array and the 3-D color array.
    height, width = image.shape[:2]

    video_writer = cv2.VideoWriter(
        video_path,
        cv2.VideoWriter_fourcc(*fourcc),
        fps,
        (width, height),
        isColor=is_color,
    )
    try:
        if not video_writer.isOpened():
            # An unopened writer silently drops every frame.
            raise RuntimeError(
                f"Could not open video writer for {video_path!r} "
                f"with fourcc {fourcc!r}")
        for _ in range(num_frames):
            video_writer.write(image)
    finally:
        # Always release so the file handle is closed even on failure.
        video_writer.release()
    return video_path
67+
68+
69+
def cosine_similarity(A: npt.NDArray,
                      B: npt.NDArray,
                      axis: int = -1) -> npt.NDArray:
    """Return the cosine similarity of ``A`` and ``B`` along ``axis``.

    The result is the sum of the element-wise product over ``axis`` divided
    by the product of the two norms over the same axis. No guard is applied
    against a zero norm.
    """
    dot_product = np.sum(A * B, axis=axis)
    norm_a = np.linalg.norm(A, axis=axis)
    norm_b = np.linalg.norm(B, axis=axis)
    return dot_product / (norm_a * norm_b)
75+
76+
77+
def normalize_image(image: npt.NDArray) -> npt.NDArray:
    """Cast ``image`` to float32 and divide every value by 255."""
    max_pixel_value = 255.0
    as_float = image.astype(np.float32)
    return as_float / max_pixel_value

vllm/assets/video.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
5959
if idx in frame_indices: # only decompress needed
6060
ret, frame = cap.retrieve()
6161
if ret:
62-
frames.append(frame)
62+
# OpenCV uses BGR format, we need to convert it to RGB
63+
# for PIL and transformers compatibility
64+
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
6365

6466
frames = np.stack(frames)
6567
if len(frames) < num_frames:
@@ -71,10 +73,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
7173
def video_to_pil_images_list(path: str,
                             num_frames: int = -1) -> list[Image.Image]:
    """Load frames with ``video_to_ndarrays`` and wrap each in a PIL image.

    No color conversion is applied here; each frame is passed to
    ``Image.fromarray`` as returned by ``video_to_ndarrays``.
    """
    return list(map(Image.fromarray, video_to_ndarrays(path, num_frames)))
7877

7978

8079
def video_get_metadata(path: str) -> dict[str, Any]:

0 commit comments

Comments
 (0)