Skip to content

Commit f5d0f47

Browse files
[Frontend] Improve error message for too many mm items (#22114)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent b690e34 commit f5d0f47

File tree

5 files changed

+52
-51
lines changed

5 files changed

+52
-51
lines changed

tests/entrypoints/test_chat_utils.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -579,10 +579,7 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
579579
warnings.filterwarnings(
580580
"ignore",
581581
message="coroutine 'async_get_and_parse_image' was never awaited")
582-
with pytest.raises(
583-
ValueError,
584-
match="At most 2 image\\(s\\) may be provided in one request\\."
585-
):
582+
with pytest.raises(ValueError, match="At most"):
586583
parse_chat_messages(
587584
[{
588585
"role":
@@ -622,10 +619,7 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
622619
warnings.filterwarnings(
623620
"ignore",
624621
message="coroutine 'async_get_and_parse_image' was never awaited")
625-
with pytest.raises(
626-
ValueError,
627-
match="At most 2 image\\(s\\) may be provided in one request\\."
628-
):
622+
with pytest.raises(ValueError, match="At most"):
629623
parse_chat_messages(
630624
[{
631625
"role":

tests/multimodal/test_processing.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33

44
from contextlib import nullcontext
55
from typing import Optional, cast
6-
from unittest.mock import MagicMock
76

87
import numpy as np
98
import pytest
@@ -957,15 +956,14 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
957956
)
958957

959958
processor = MULTIMODAL_REGISTRY.create_processor(model_config)
960-
profiler = MultiModalProfiler(processor)
959+
processor._supported_mm_limits = {"image": num_supported}
961960

962-
mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
963-
processor.info.get_supported_mm_limits = mock_supported_mm_limits
961+
profiler = MultiModalProfiler(processor)
964962

965963
if is_valid:
966964
exc_ctx = nullcontext()
967965
else:
968-
exc_ctx = pytest.raises(ValueError, match="The model only supports")
966+
exc_ctx = pytest.raises(ValueError, match="At most")
969967

970968
with exc_ctx:
971969
profiler.get_decoder_dummy_data(
@@ -1002,7 +1000,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
10021000
if is_valid:
10031001
exc_ctx = nullcontext()
10041002
else:
1005-
exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")
1003+
exc_ctx = pytest.raises(ValueError, match="At most")
10061004

10071005
with exc_ctx:
10081006
processor.apply(

vllm/entrypoints/chat_utils.py

Lines changed: 10 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -535,9 +535,10 @@ def model_config(self) -> ModelConfig:
535535
return self._model_config
536536

537537
@cached_property
538-
def model_cls(self):
538+
def model_cls(self) -> type[SupportsMultiModal]:
539539
from vllm.model_executor.model_loader import get_model_cls
540-
return get_model_cls(self.model_config)
540+
model_cls = get_model_cls(self.model_config)
541+
return cast(type[SupportsMultiModal], model_cls)
541542

542543
@property
543544
def allowed_local_media_path(self):
@@ -547,31 +548,23 @@ def allowed_local_media_path(self):
547548
def mm_registry(self):
548549
return MULTIMODAL_REGISTRY
549550

551+
@cached_property
552+
def mm_processor(self):
553+
return self.mm_registry.create_processor(self.model_config)
554+
550555
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
551556
"""
552557
Add a multi-modal item to the current prompt and returns the
553558
placeholder string to use, if any.
554559
"""
555-
mm_registry = self.mm_registry
556-
model_config = self.model_config
557-
model_cls = cast(SupportsMultiModal, self.model_cls)
558-
559560
input_modality = modality.replace("_embeds", "")
561+
num_items = len(self._items_by_modality[modality]) + 1
560562

561-
mm_processor = mm_registry.create_processor(model_config)
562-
allowed_counts = mm_processor.info.get_allowed_mm_limits()
563-
allowed_count = allowed_counts.get(input_modality, 0)
564-
565-
current_count = len(self._items_by_modality[modality]) + 1
566-
if current_count > allowed_count:
567-
raise ValueError(
568-
f"At most {allowed_count} {modality}(s) may be provided in "
569-
"one request. You can set `--limit-mm-per-prompt` to "
570-
"increase this limit if the model supports it.")
563+
self.mm_processor.validate_num_items(input_modality, num_items)
571564

572565
self._items_by_modality[modality].append(item)
573566

574-
return model_cls.get_placeholder_str(modality, current_count)
567+
return self.model_cls.get_placeholder_str(modality, num_items)
575568

576569
@abstractmethod
577570
def create_parser(self) -> "BaseMultiModalContentParser":

vllm/multimodal/processing.py

Lines changed: 35 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
import json
43
import sys
54
from abc import ABC, abstractmethod
65
from collections import defaultdict
@@ -1156,6 +1155,18 @@ def __init__(self,
11561155

11571156
self.data_parser = self._get_data_parser()
11581157

1158+
# Avoid unnecessary recomputation
1159+
self._supported_mm_limits = self.info.get_supported_mm_limits()
1160+
self._allowed_mm_limits = self.info.get_allowed_mm_limits()
1161+
1162+
@property
1163+
def supported_mm_limits(self):
1164+
return self._supported_mm_limits
1165+
1166+
@property
1167+
def allowed_mm_limits(self):
1168+
return self._allowed_mm_limits
1169+
11591170
def __call__(
11601171
self,
11611172
prompt: str,
@@ -1176,6 +1187,28 @@ def _get_data_parser(self) -> MultiModalDataParser:
11761187
"""
11771188
return MultiModalDataParser()
11781189

1190+
def validate_num_items(
1191+
self,
1192+
modality: str,
1193+
num_items: int,
1194+
) -> None:
1195+
supported_limit = self.supported_mm_limits.get(modality, 0)
1196+
allowed_limit = self.allowed_mm_limits.get(modality, 0)
1197+
1198+
if supported_limit is None:
1199+
supported_limit = allowed_limit
1200+
1201+
limit = min(supported_limit, allowed_limit)
1202+
1203+
if num_items > limit:
1204+
msg = (f"At most {limit} {modality}(s) may be provided in "
1205+
"one prompt.")
1206+
1207+
if num_items <= supported_limit:
1208+
msg += " Set `--limit-mm-per-prompt` to increase this limit."
1209+
1210+
raise ValueError(msg)
1211+
11791212
def _to_mm_items(
11801213
self,
11811214
mm_data: MultiModalDataDict,
@@ -1188,26 +1221,9 @@ def _to_mm_items(
11881221
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
11891222
"""
11901223
mm_items = self.data_parser.parse_mm_data(mm_data)
1191-
supported_mm_limits = self.info.get_supported_mm_limits()
1192-
allowed_mm_limits = self.info.get_allowed_mm_limits()
11931224

11941225
for modality, items in mm_items.items():
1195-
supported_limit = supported_mm_limits.get(modality, 0)
1196-
allowed_limit = allowed_mm_limits.get(modality, 0)
1197-
num_items = len(items)
1198-
1199-
if supported_limit is not None and num_items > supported_limit:
1200-
raise ValueError(
1201-
f"The model only supports at most {supported_limit} "
1202-
f"{modality} items, but you passed {num_items} "
1203-
f"{modality} items in the same prompt.")
1204-
1205-
if num_items > allowed_limit:
1206-
raise ValueError(
1207-
"You set or defaulted to "
1208-
f"'{json.dumps({modality: allowed_limit})}' in "
1209-
f"`--limit-mm-per-prompt`, but passed {num_items} "
1210-
f"{modality} items in the same prompt.")
1226+
self.validate_num_items(modality, len(items))
12111227

12121228
return mm_items
12131229

vllm/multimodal/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -156,7 +156,7 @@ def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
156156
return self.processor.dummy_inputs
157157

158158
def get_mm_limits(self) -> Mapping[str, int]:
159-
return self.processing_info.get_allowed_mm_limits()
159+
return self.processor.allowed_mm_limits
160160

161161
def _get_dummy_mm_inputs(
162162
self,

0 commit comments

Comments
 (0)