
[Core] Use individual MM items in P0/P1 cache and model runner #22570


Merged (16 commits) on Aug 13, 2025
233 changes: 79 additions & 154 deletions tests/multimodal/test_utils.py
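
In short, this change replaces `merge_and_sort_multimodal_metadata`, which returned flattened lists of modalities, placeholder ranges, and hashes, with `argsort_mm_positions`, which returns only a sort order over the individual multimodal items as (modality, index) pairs keyed by placeholder offset. Below is a minimal sketch of the behavior the updated tests encode; it is inferred from the test expectations, not taken from the actual implementation in vllm.multimodal.utils, which may differ:

    from vllm.multimodal.inputs import MultiModalPlaceholderDict

    def argsort_mm_positions(
            mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
        # Enumerate every placeholder as a (modality, index-within-modality)
        # pair, then order all items by where they start in the prompt.
        flat = ((modality, idx) for modality, ranges in mm_positions.items()
                for idx in range(len(ranges)))
        return sorted(flat, key=lambda mi: mm_positions[mi[0]][mi[1]].offset)
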
@@ -5,7 +5,7 @@
 import mimetypes
 import os
 from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import TYPE_CHECKING, NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple

 import numpy as np
 import pytest
Expand All @@ -19,14 +19,12 @@
initialize_model_parallel)
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
merge_and_sort_multimodal_metadata,
from vllm.multimodal.utils import (MediaConnector, argsort_mm_positions,
run_dp_sharded_vision_model)
from vllm.platforms import current_platform
from vllm.utils import get_open_port, update_environment_variables

if TYPE_CHECKING:
from vllm.multimodal.hasher import MultiModalHashDict
from vllm.multimodal.inputs import MultiModalPlaceholderDict

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@@ -178,54 +176,45 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
     assert metadata_sync == metadata_async


-# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
+# Used for `test_argsort_mm_positions`.
 class TestCase(NamedTuple):
     mm_positions: "MultiModalPlaceholderDict"
-    mm_hashes: Optional["MultiModalHashDict"]
-    expected_modalities: list[str]
-    expected_ranges: list[PlaceholderRange]
-    expected_hashes: Optional[list[str]]
+    expected_modality_idxs: list[tuple[str, int]]


-def test_merge_and_sort_multimodal_metadata():
+def test_argsort_mm_positions():

     test_cases = [
-        # Single modality should return result as is but flattened
+        # Single modality
+        ## Internally sorted
         TestCase(
             mm_positions={
                 "image": [
                     PlaceholderRange(offset=0, length=2),
                     PlaceholderRange(offset=3, length=2),
                 ]
             },
-            mm_hashes={"image": ["hash1", "hash2"]},
-            expected_modalities=["image", "image"],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=2),
-                PlaceholderRange(offset=3, length=2),
+            expected_modality_idxs=[
+                ("image", 0),
+                ("image", 1),
             ],
-            expected_hashes=["hash1", "hash2"],
         ),

-        # Single modality without hashes return None for mm hash.
+        ## Internally unsorted
         TestCase(
             mm_positions={
                 "image": [
+                    PlaceholderRange(offset=3, length=2),
                     PlaceholderRange(offset=0, length=2),
-                    PlaceholderRange(offset=2, length=2),
                 ]
             },
-            mm_hashes=None,
-            expected_modalities=["image", "image"],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=2),
-                PlaceholderRange(offset=2, length=2),
+            expected_modality_idxs=[
+                ("image", 1),
+                ("image", 0),
             ],
-            expected_hashes=None,
         ),

-        # Multiple modalities with hashes should return sorted modalities
-        # and flattened ranges and hashes.
+        # Two modalities
+        ## Internally sorted
         TestCase(
             mm_positions={
                 "image": [
@@ -237,47 +226,54 @@ def test_merge_and_sort_multimodal_metadata():
                     PlaceholderRange(offset=2, length=3),
                 ]
             },
-            mm_hashes={
-                "image": ["image_hash1", "image_hash2"],
-                "audio": ["audio_hash1", "audio_hash2"],
-            },
-            expected_modalities=["audio", "audio", "image", "image"],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=2),
-                PlaceholderRange(offset=2, length=3),
-                PlaceholderRange(offset=7, length=4),
-                PlaceholderRange(offset=11, length=5),
+            expected_modality_idxs=[
+                ("audio", 0),
+                ("audio", 1),
+                ("image", 0),
+                ("image", 1),
             ],
-            expected_hashes=[
-                "audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
-            ],
         ),
+        ## Interleaved, internally sorted
+        TestCase(
+            mm_positions={
+                "image": [
+                    PlaceholderRange(offset=0, length=4),
+                    PlaceholderRange(offset=8, length=2),
+                ],
+                "audio": [
+                    PlaceholderRange(offset=5, length=2),
+                    PlaceholderRange(offset=11, length=4),
+                ]
+            },
+            expected_modality_idxs=[
+                ("image", 0),
+                ("audio", 0),
+                ("image", 1),
+                ("audio", 1),
+            ],
+        ),

-        # Multiple modalities without hashes should return sorted modalities
-        # and flattened ranges and None.
+        ## Interleaved, internally unsorted
         TestCase(
             mm_positions={
                 "image": [
-                    PlaceholderRange(offset=7, length=4),
-                    PlaceholderRange(offset=11, length=5),
+                    PlaceholderRange(offset=8, length=2),
+                    PlaceholderRange(offset=0, length=4),
                 ],
                 "audio": [
-                    PlaceholderRange(offset=0, length=2),
-                    PlaceholderRange(offset=2, length=3),
+                    PlaceholderRange(offset=11, length=4),
+                    PlaceholderRange(offset=5, length=2),
                 ]
             },
-            mm_hashes=None,
-            expected_modalities=["audio", "audio", "image", "image"],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=2),
-                PlaceholderRange(offset=2, length=3),
-                PlaceholderRange(offset=7, length=4),
-                PlaceholderRange(offset=11, length=5),
+            expected_modality_idxs=[
+                ("image", 1),
+                ("audio", 1),
+                ("image", 0),
+                ("audio", 0),
             ],
-            expected_hashes=None,
         ),

+        # Three modalities
+        ## Internally sorted
         TestCase(
             mm_positions={
                 "image": [
@@ -293,72 +289,16 @@
                     PlaceholderRange(offset=12, length=6),
                 ]
             },
-            mm_hashes={
-                "image": ["image_hash1", "image_hash2"],
-                "audio": ["audio_hash1"],
-                "video": ["video_hash1", "video_hash2", "video_hash3"]
-            },
-            expected_modalities=[
-                "audio", "video", "video", "video", "image", "image"
-            ],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=2),
-                PlaceholderRange(offset=3, length=4),
-                PlaceholderRange(offset=7, length=5),
-                PlaceholderRange(offset=12, length=6),
-                PlaceholderRange(offset=15, length=7),
-                PlaceholderRange(offset=22, length=8),
-            ],
-            expected_hashes=[
-                "audio_hash1", "video_hash1", "video_hash2", "video_hash3",
-                "image_hash1", "image_hash2"
-            ],
-        ),
-    ]
-
-    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
-         expected_hashes) in test_cases:
-        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
-            mm_positions, mm_hashes)
-
-        assert modalities == expected_modalities
-        assert ranges == expected_ranges
-        assert hashes == expected_hashes
-
-
-def test_merge_and_sort_multimodal_metadata_with_interleaving():
-
-    test_cases = [
-
-        # <image> <audio> <image> <audio>
-        TestCase(
-            mm_positions={
-                "image": [
-                    PlaceholderRange(offset=0, length=4),
-                    PlaceholderRange(offset=8, length=2),
-                ],
-                "audio": [
-                    PlaceholderRange(offset=5, length=2),
-                    PlaceholderRange(offset=11, length=4),
-                ]
-            },
-            mm_hashes={
-                "image": ["image_hash1", "image_hash2"],
-                "audio": ["audio_hash1", "audio_hash2"],
-            },
-            expected_modalities=["image", "audio", "image", "audio"],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=4),
-                PlaceholderRange(offset=5, length=2),
-                PlaceholderRange(offset=8, length=2),
-                PlaceholderRange(offset=11, length=4),
-            ],
-            expected_hashes=[
-                "image_hash1", "audio_hash1", "image_hash2", "audio_hash2"
+            expected_modality_idxs=[
+                ("audio", 0),
+                ("video", 0),
+                ("video", 1),
+                ("video", 2),
+                ("image", 0),
+                ("image", 1),
             ],
         ),

-        # <image> <image> <audio> <video> <image>
+        ## Interleaved, internally sorted
         TestCase(
             mm_positions={
                 "image": [
@@ -373,58 +313,43 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
                     PlaceholderRange(offset=8, length=5),
                 ]
             },
-            mm_hashes=None,
-            expected_modalities=["image", "image", "audio", "video", "image"],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=2),
-                PlaceholderRange(offset=2, length=3),
-                PlaceholderRange(offset=5, length=2),
-                PlaceholderRange(offset=8, length=5),
-                PlaceholderRange(offset=20, length=4),
+            expected_modality_idxs=[
+                ("image", 0),
+                ("image", 1),
+                ("audio", 0),
+                ("video", 0),
+                ("image", 2),
             ],
-            expected_hashes=None,
         ),

-        # <image> <audio> <video> <image> with hashes
+        ## Interleaved, internally unsorted
         TestCase(
             mm_positions={
                 "image": [
                     PlaceholderRange(offset=0, length=2),
-                    PlaceholderRange(offset=18, length=4),
+                    PlaceholderRange(offset=20, length=4),
+                    PlaceholderRange(offset=2, length=3),
                 ],
                 "audio": [
-                    PlaceholderRange(offset=6, length=2),
+                    PlaceholderRange(offset=5, length=2),
                 ],
                 "video": [
-                    PlaceholderRange(offset=10, length=5),
+                    PlaceholderRange(offset=8, length=5),
                 ]
             },
-            mm_hashes={
-                "image": ["image_hash1", "image_hash2"],
-                "audio": ["audio_hash1"],
-                "video": ["video_hash1"],
-            },
-            expected_modalities=["image", "audio", "video", "image"],
-            expected_ranges=[
-                PlaceholderRange(offset=0, length=2),
-                PlaceholderRange(offset=6, length=2),
-                PlaceholderRange(offset=10, length=5),
-                PlaceholderRange(offset=18, length=4),
-            ],
-            expected_hashes=[
-                "image_hash1", "audio_hash1", "video_hash1", "image_hash2"
+            expected_modality_idxs=[
+                ("image", 0),
+                ("image", 2),
+                ("audio", 0),
+                ("video", 0),
+                ("image", 1),
             ],
         ),
     ]

-    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
-         expected_hashes) in test_cases:
-        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
-            mm_positions, mm_hashes)
+    for mm_positions, expected_modality_idxs in test_cases:
+        modality_idxs = argsort_mm_positions(mm_positions)

-        assert modalities == expected_modalities
-        assert ranges == expected_ranges
-        assert hashes == expected_hashes
+        assert modality_idxs == expected_modality_idxs


 class SimpleLinearModel(torch.nn.Module):
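
For reference, here is a usage example mirroring the interleaved two-modality case from the new tests; the `argsort_mm_positions` import path matches the import added in the diff above:

    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.multimodal.utils import argsort_mm_positions

    mm_positions = {
        "image": [PlaceholderRange(offset=0, length=4),
                  PlaceholderRange(offset=8, length=2)],
        "audio": [PlaceholderRange(offset=5, length=2),
                  PlaceholderRange(offset=11, length=4)],
    }
    # Items ordered by start offset: 0, 5, 8, 11
    assert argsort_mm_positions(mm_positions) == [
        ("image", 0), ("audio", 0), ("image", 1), ("audio", 1),
    ]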