Commit 79aa244

[Multi Modal] Configurable MM Profiling (#25631)

Authored by wwl2755 and hmellor.

Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>

1 parent 2ed3f20 commit 79aa244

60 files changed: +654 -99 lines. (Large commits hide some content by default; only a subset of the diff is reproduced below.)
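At a glance: this commit makes multimodal memory profiling configurable. `limit_mm_per_prompt` previously took only per-modality item counts; it can now also carry per-modality dummy-data options (image resolution, video frame count, audio length) that bound the synthetic inputs the profiler generates. A minimal usage sketch against the Python entrypoint (the model name is a placeholder; the option keys follow the `MultiModalConfig` docstring shown below):

```python
from vllm import LLM

# Legacy count-only form still works: {"image": 16, "video": 2}.
# New in this commit: per-modality dicts carrying dummy-data options
# that cap the synthetic inputs used for memory profiling.
llm = LLM(
    model="<your-multimodal-model>",  # placeholder
    limit_mm_per_prompt={
        "image": {"count": 5, "width": 512, "height": 512},
        "video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
    },
)
```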

docs/contributing/model/multimodal.md

Lines changed: 10 additions & 2 deletions

````diff
@@ -258,17 +258,21 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
 
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 ```
 
@@ -438,16 +442,20 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
         num_images = mm_counts.get("image", 0)
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 ```
````
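The new `overrides` argument gives `_get_dummy_images` a chance to shrink the profiled images. As a purely illustrative sketch (a hypothetical helper, not the actual vLLM implementation), override handling could look like this:

```python
from typing import Optional

from PIL import Image

from vllm.config.multimodal import ImageDummyOptions


def make_dummy_images(width: int, height: int, num_images: int,
                      overrides: Optional[ImageDummyOptions] = None
                      ) -> list[Image.Image]:
    # Hypothetical stand-in for _get_dummy_images: user-supplied options
    # may shrink (never enlarge) the profiled resolution.
    if overrides is not None and overrides.width is not None:
        width = min(width, overrides.width)
    if overrides is not None and overrides.height is not None:
        height = min(height, overrides.height)
    return [Image.new("RGB", (width, height)) for _ in range(num_images)]
```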

tests/models/multimodal/processing/test_common.py

Lines changed: 19 additions & 3 deletions

```diff
@@ -12,6 +12,8 @@
 from PIL import Image
 
 from vllm.config import ModelConfig
+from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
+                                    ImageDummyOptions, VideoDummyOptions)
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import MultiModalInputs
@@ -112,12 +114,26 @@ def _test_processing_correctness(
 
     processing_info = factories.info(ctx)
     supported_mm_limits = processing_info.get_supported_mm_limits()
-    limit_mm_per_prompt = {
+    # Keep integer limits for local data generation
+    limit_mm_per_prompt_ints = {
         modality: 3 if limit is None else limit
         for modality, limit in supported_mm_limits.items()
     }
 
-    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+    def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
+        if modality == "video":
+            return VideoDummyOptions(count=count)
+        if modality == "image":
+            return ImageDummyOptions(count=count)
+        if modality == "audio":
+            return AudioDummyOptions(count=count)
+        return BaseDummyOptions(count=count)
+
+    # Assign normalized DummyOptions to the model config
+    model_config.get_multimodal_config().limit_per_prompt = {
+        modality: _to_dummy_options(modality, count)
+        for modality, count in limit_mm_per_prompt_ints.items()
+    }
 
     baseline_processor = factories.build_processor(ctx, cache=None)
     cached_processor = factories.build_processor(ctx, cache=cache)
@@ -150,7 +166,7 @@ def _test_processing_correctness(
         k:
         [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
          for _ in range(rng.randint(limit + 1))]
-        for k, limit in limit_mm_per_prompt.items()
+        for k, limit in limit_mm_per_prompt_ints.items()
     }
 
     mm_counts = {k: len(vs) for k, vs in mm_data.items()}
```

tests/models/multimodal/processing/test_mllama4.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -17,23 +17,23 @@ def test_profiling(model_id: str, max_model_len: int):
     model_config_kwargs = {
         "max_model_len": max_model_len,
     }
+    mm_counts = {"image": 1}
     ctx = build_model_context(
         model_id,
         model_config_kwargs=model_config_kwargs,
-        limit_mm_per_prompt={"image": 1},
+        limit_mm_per_prompt=mm_counts,
     )
 
-    mm_config = ctx.get_mm_config()
     processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
     profiler = MultiModalProfiler(processor)
 
     decoder_dummy_data = profiler.get_decoder_dummy_data(
         max_model_len,
-        mm_counts=mm_config.limit_per_prompt,
+        mm_counts=mm_counts,
     )
     dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
         max_model_len,
-        mm_counts=mm_config.limit_per_prompt,
+        mm_counts=mm_counts,
     )
 
     hf_config = ctx.get_hf_config(Llama4Config)
@@ -58,7 +58,7 @@ def test_profiling(model_id: str, max_model_len: int):
 
     profiled_tokens = profiler.get_mm_max_contiguous_tokens(
         max_model_len,
-        mm_counts=mm_config.limit_per_prompt,
+        mm_counts=mm_counts,
     )
 
     assert total_tokens == profiled_tokens["image"]
```
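The switch from `mm_config.limit_per_prompt` to a plain `mm_counts` dict follows from the config change further down: `limit_per_prompt` values are now option objects rather than ints, while the profiler's `mm_counts` parameters still expect integer counts. A tiny illustration of the mismatch the updated test sidesteps:

```python
from vllm.config.multimodal import ImageDummyOptions

# After this commit, limit_per_prompt stores option objects...
limit_per_prompt = {"image": ImageDummyOptions(count=1)}

# ...whereas profiler APIs such as get_decoder_dummy_data(mm_counts=...)
# still take Mapping[str, int], so the test supplies the counts directly.
mm_counts = {"image": limit_per_prompt["image"].count}
assert mm_counts == {"image": 1}
```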

tests/models/multimodal/processing/test_tensor_schema.py

Lines changed: 16 additions & 1 deletion

```diff
@@ -15,6 +15,8 @@
 from PIL import Image
 
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
+from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
+                                    ImageDummyOptions, VideoDummyOptions)
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
@@ -236,7 +238,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
         modality: 3 if limit is None else limit
         for modality, limit in supported_mm_limits.items()
     }
-    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+
+    def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
+        if modality == "video":
+            return VideoDummyOptions(count=count)
+        if modality == "image":
+            return ImageDummyOptions(count=count)
+        if modality == "audio":
+            return AudioDummyOptions(count=count)
+        return BaseDummyOptions(count=count)
+
+    model_config.get_multimodal_config().limit_per_prompt = {
+        modality: _to_dummy_options(modality, count)
+        for modality, count in limit_mm_per_prompt.items()
+    }
     processor = factories.build_processor(ctx, cache=None)
 
     with initialize_dummy_model(model_cls, model_config) as model:
```

vllm/config/model.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -276,7 +276,9 @@ class ModelConfig:
     multimodal_config: Optional[MultiModalConfig] = None
     """Configuration for multimodal model. If `None`, this will be inferred
    from the architecture of `self.model`."""
-    limit_mm_per_prompt: InitVar[Optional[dict[str, int]]] = None
+    limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int,
+                                                          dict[str,
+                                                               int]]]]] = None
     media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None
     mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None
     mm_processor_cache_gb: InitVar[Optional[float]] = None
```

vllm/config/multimodal.py

Lines changed: 83 additions & 12 deletions

```diff
@@ -4,28 +4,68 @@
 import hashlib
 from collections.abc import Mapping
 from dataclasses import field
-from typing import Any, Literal, Optional
+from typing import Any, Literal, Optional, Union
 
+from pydantic import ConfigDict, Field, field_validator
 from pydantic.dataclasses import dataclass
 
-import vllm.envs as envs
 from vllm.config.utils import config
 
+
+@dataclass
+class BaseDummyOptions:
+    """Base options for generating dummy data during profiling."""
+    count: int = Field(999, ge=0)
+
+
+@dataclass(config=ConfigDict(extra="forbid"))
+class VideoDummyOptions(BaseDummyOptions):
+    """Options for generating dummy video data during profiling."""
+    num_frames: Optional[int] = Field(None, gt=0)
+    width: Optional[int] = Field(None, gt=0)
+    height: Optional[int] = Field(None, gt=0)
+
+
+@dataclass(config=ConfigDict(extra="forbid"))
+class ImageDummyOptions(BaseDummyOptions):
+    """Options for generating dummy image data during profiling."""
+    width: Optional[int] = Field(None, gt=0)
+    height: Optional[int] = Field(None, gt=0)
+
+
+@dataclass(config=ConfigDict(extra="forbid"))
+class AudioDummyOptions(BaseDummyOptions):
+    """Options for generating dummy audio data during profiling."""
+    length: Optional[int] = Field(None, gt=0)
+
+
 MMEncoderTPMode = Literal["weights", "data"]
 MMCacheType = Literal["shm", "lru"]
+DummyOptions = Union[BaseDummyOptions, VideoDummyOptions, ImageDummyOptions,
+                     AudioDummyOptions]
 
 
 @config
 @dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""
 
-    limit_per_prompt: dict[str, int] = field(default_factory=dict)
-    """The maximum number of input items allowed per prompt for each modality.
-    Defaults to 1 (V0) or 999 (V1) for each modality.
+    limit_per_prompt: dict[str, DummyOptions] = field(default_factory=dict)
+    """The maximum number of input items and options allowed per
+    prompt for each modality.
+    Defaults to 999 for each modality.
+
+    Legacy format (count only):
+        {"image": 16, "video": 2}
+
+    Configurable format (with options):
+        {"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
+         "image": {"count": 5, "width": 512, "height": 512}}
 
-    For example, to allow up to 16 images and 2 videos per prompt:
-    `{"image": 16, "video": 2}`"""
+    Mixed format (combining both):
+        {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
+                                "height": 512}}
+    """
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
     """Additional args passed to process media inputs, keyed by modalities.
     For example, to set num_frames for video, set
@@ -84,6 +124,27 @@ class MultiModalConfig:
     from each video to be pruned.
     """
 
+    @field_validator("limit_per_prompt", mode="before")
+    @classmethod
+    def _validate_limit_per_prompt(
+            cls, value: dict[str, Union[int,
+                                        dict[str,
+                                             int]]]) -> dict[str, DummyOptions]:
+        for k, v in value.items():
+            # Handle legacy format where only count is specified
+            if isinstance(v, int):
+                v = {"count": v}
+            # Convert to the appropriate DummyOptions subclass
+            if k == "video":
+                value[k] = VideoDummyOptions(**v)
+            elif k == "image":
+                value[k] = ImageDummyOptions(**v)
+            elif k == "audio":
+                value[k] = AudioDummyOptions(**v)
+            else:
+                value[k] = BaseDummyOptions(**v)
+        return value
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -106,12 +167,22 @@ def compute_hash(self) -> str:
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
-        for the given modality.
+        for the given modality (backward compatible).
+        """
+        limit_data = self.limit_per_prompt.get(modality)
+
+        if limit_data is None:
+            # Unspecified modality is set to 999 by default
+            return 999
+        return limit_data.count
+
+    def get_dummy_options(self, modality: str) -> Optional[BaseDummyOptions]:
+        """
+        Get the configurable dummy data options for a modality.
+        Returns None if no options are configured for this modality.
         """
-        return self.limit_per_prompt.get(
-            modality,
-            999 if envs.VLLM_USE_V1 else 1,
-        )
+        # All values are now DummyOptions after normalization
+        return self.limit_per_prompt.get(modality)
 
     def merge_mm_processor_kwargs(
         self,
```
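Putting the validator and the two accessors together, normalization should behave roughly as below (a sketch, assuming `MultiModalConfig` can be constructed standalone with just this field):

```python
from vllm.config.multimodal import MultiModalConfig

# Mixed legacy/configurable input; the "before" validator converts both
# forms into DummyOptions subclasses at construction time.
cfg = MultiModalConfig(limit_per_prompt={
    "image": 16,                              # legacy int count
    "video": {"count": 2, "num_frames": 32},  # configurable dict
})

assert cfg.get_limit_per_prompt("image") == 16   # count preserved
assert cfg.get_limit_per_prompt("video") == 2
assert cfg.get_limit_per_prompt("audio") == 999  # unspecified default
assert cfg.get_dummy_options("video").num_frames == 32
assert cfg.get_dummy_options("audio") is None    # nothing configured
```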

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -376,7 +376,7 @@ class EngineArgs:
     quantization: Optional[QuantizationMethods] = ModelConfig.quantization
     enforce_eager: bool = ModelConfig.enforce_eager
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    limit_mm_per_prompt: dict[str, int] = \
+    limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = \
         get_field(MultiModalConfig, "limit_per_prompt")
     interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
     media_io_kwargs: dict[str, dict[str,
```
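The widened `EngineArgs` field mirrors what the `--limit-mm-per-prompt` CLI flag accepts, so both value shapes below should be valid (a sketch; the model name is a placeholder):

```python
from vllm.engine.arg_utils import EngineArgs

# Both value shapes fit the widened dict[str, Union[int, dict[str, int]]]:
args = EngineArgs(
    model="<model>",  # placeholder
    limit_mm_per_prompt={
        "image": 16,                              # legacy count
        "video": {"count": 1, "num_frames": 32},  # new options dict
    },
)
```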

vllm/model_executor/models/aria.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -10,6 +10,7 @@
 from transformers.models.aria.processing_aria import AriaProcessor
 
 from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -431,17 +432,21 @@ def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         vision_config = self.info.get_vision_config()
 
         max_image_size = vision_config.image_size
         num_images = mm_counts.get("image", 0)
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=max_image_size,
                                    height=max_image_size,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 
```

vllm/model_executor/models/aya_vision.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -16,6 +16,7 @@
     get_optimal_tiled_canvas)
 
 from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
@@ -166,16 +167,20 @@ def get_dummy_mm_data(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
         image_size = \
             self.info.get_image_size_with_most_features()
 
+        image_overrides = mm_options.get("image") if mm_options else None
+
         return {
             "image":
             self._get_dummy_images(width=image_size.width,
                                    height=image_size.height,
-                                   num_images=num_images)
+                                   num_images=num_images,
+                                   overrides=image_overrides)
         }
 
```
