Commit b7e8e4e

[Bugfix] Always apply MM processor even when no MM items are passed (#26240)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 432e1cb commit b7e8e4e

6 files changed, with 99 additions and 27 deletions

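For orientation, a hedged sketch of the behavior this commit changes (illustrative usage only, not part of the diff): with a multimodal model, input preprocessing now always routes prompts through the multimodal HF processor, even when no images, videos, or audio are attached, so processor-inserted special tokens are applied consistently. Assuming a standard vLLM installation, exercising that path could look like:

# Illustrative only, not from this commit. "facebook/chameleon-7b" is the
# multimodal model used by the new test in tests/test_inputs.py below.
from vllm import LLM

llm = LLM(model="facebook/chameleon-7b")

# No multimodal items are passed here. Before this fix, such prompts skipped
# the multimodal processor, which could drop processor-inserted special tokens
# (e.g. the sep token asserted in the new test).
outputs = llm.generate("Describe a sunny day at the beach.")
print(outputs[0].outputs[0].text)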

tests/conftest.py

Lines changed: 14 additions & 15 deletions

@@ -46,7 +46,6 @@
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
-from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.multimodal.utils import fetch_image
@@ -760,35 +759,35 @@ def get_inputs(
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
-    ) -> list[TextPrompt]:
-
+    ) -> list[dict[str, Any]]:
         if any(x is not None and len(x) != len(prompts)
                for x in [images, videos, audios]):
             raise ValueError(
                 "All non-None multimodal inputs must have the same length as "
                 "prompts")

-        inputs = []
+        inputs = list[dict[str, Any]]()
         for i, prompt in enumerate(prompts):
-            multi_modal_data = {}
+            prompt_dict = dict[str, Any]()
+            if isinstance(prompt, str):
+                prompt_dict["prompt"] = prompt
+            elif isinstance(prompt, list):
+                prompt_dict["prompt_token_ids"] = prompt
+            else:
+                prompt_dict["prompt_embeds"] = prompt
+
+            multi_modal_data = dict[str, Any]()
             if images is not None and (image := images[i]) is not None:
                 multi_modal_data["image"] = image
             if videos is not None and (video := videos[i]) is not None:
                 multi_modal_data["video"] = video
             if audios is not None and (audio := audios[i]) is not None:
                 multi_modal_data["audio"] = audio

-            text_prompt_kwargs: dict[str, Any] = {
-                "multi_modal_data": multi_modal_data or None
-            }
-            if isinstance(prompt, str):
-                text_prompt_kwargs["prompt"] = prompt
-            elif isinstance(prompt, list):
-                text_prompt_kwargs["prompt_token_ids"] = prompt
-            else:
-                text_prompt_kwargs["prompt_embeds"] = prompt
+            if multi_modal_data:
+                prompt_dict["multi_modal_data"] = multi_modal_data

-            inputs.append(TextPrompt(**text_prompt_kwargs))
+            inputs.append(prompt_dict)

         return inputs

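The helper above now returns plain prompt dictionaries instead of TextPrompt objects, and only attaches multi_modal_data when there are actual items. A standalone, simplified sketch of that pattern (illustrative helper name and inputs, not the test code itself):

# Simplified sketch of the prompt-dict construction used by the test helper
# above; build_prompt_dict is an illustrative helper, not part of vLLM.
from typing import Any, Optional, Union

def build_prompt_dict(
    prompt: Union[str, list[int]],
    image: Optional[Any] = None,
) -> dict[str, Any]:
    # Text prompts go under "prompt"; pre-tokenized prompts under
    # "prompt_token_ids".
    prompt_dict: dict[str, Any] = (
        {"prompt": prompt} if isinstance(prompt, str)
        else {"prompt_token_ids": prompt}
    )

    # Only attach "multi_modal_data" when there is something to attach,
    # matching the updated helper (no more `multi_modal_data or None`).
    multi_modal_data: dict[str, Any] = {}
    if image is not None:
        multi_modal_data["image"] = image
    if multi_modal_data:
        prompt_dict["multi_modal_data"] = multi_modal_data

    return prompt_dict

# A text-only prompt yields {"prompt": "Hello"} with no multi_modal_data key.
print(build_prompt_dict("Hello"))
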
tests/test_inputs.py

Lines changed: 50 additions & 0 deletions

@@ -3,8 +3,11 @@

 import pytest

+from vllm.config import ModelConfig
 from vllm.inputs import zip_enc_dec_prompts
 from vllm.inputs.parse import parse_raw_prompts
+from vllm.inputs.preprocess import InputPreprocessor
+from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs

 pytestmark = pytest.mark.cpu_test

@@ -80,3 +83,50 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
     assert zipped['encoder_prompt'] == enc
     assert zipped['decoder_prompt'] == dec
     assert zipped['mm_processor_kwargs'] == exp_kwargs
+
+
+@pytest.mark.parametrize("model_id", [
+    "facebook/opt-125m",
+])
+@pytest.mark.parametrize("prompt", [
+    {
+        "prompt": "",
+        "multi_modal_data": {
+            "dummy": []
+        },
+    },
+    {
+        "prompt_token_ids": [],
+        "multi_modal_data": {
+            "dummy": []
+        },
+    },
+])
+def test_preprocessor_text_no_mm_inputs(model_id, prompt):
+    model_config = ModelConfig(model=model_id)
+    tokenizer = init_tokenizer_from_configs(model_config)
+    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+
+    with pytest.raises(ValueError, match="does not support multimodal inputs"):
+        input_preprocessor.preprocess(prompt)
+
+
+@pytest.mark.parametrize("model_id", [
+    "facebook/chameleon-7b",
+])
+@pytest.mark.parametrize("prompt", [
+    "",
+    {
+        "prompt_token_ids": []
+    },
+])
+def test_preprocessor_always_mm_code_path(model_id, prompt):
+    model_config = ModelConfig(model=model_id)
+    tokenizer = init_tokenizer_from_configs(model_config)
+    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+
+    # HF processor adds sep token
+    sep_token_id = tokenizer.vocab[tokenizer.sep_token]
+
+    processed_inputs = input_preprocessor.preprocess(prompt)
+    assert sep_token_id in processed_inputs["prompt_token_ids"]

vllm/inputs/preprocess.py

Lines changed: 12 additions & 4 deletions

@@ -314,15 +314,19 @@ def _process_tokens(
             parsed_content["prompt_token_ids"], tokenization_kwargs)

         inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
+        if self.model_config.is_multimodal_model:
             inputs = self._process_multimodal(
                 prompt_token_ids,
-                multi_modal_data,
+                parsed_content.get("multi_modal_data", {}),
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
             )
         else:
+            if parsed_content.get("multi_modal_data"):
+                raise ValueError(
+                    "This model does not support multimodal inputs")
+
             inputs = token_inputs(prompt_token_ids)

         if cache_salt := parsed_content.get("cache_salt"):
@@ -340,15 +344,19 @@ def _process_text(
         prompt_text = parsed_content["prompt"]

         inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
+        if self.model_config.is_multimodal_model:
             inputs = self._process_multimodal(
                 prompt_text,
-                multi_modal_data,
+                parsed_content.get("multi_modal_data", {}),
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
             )
         else:
+            if parsed_content.get("multi_modal_data"):
+                raise ValueError(
+                    "This model does not support multimodal inputs")
+
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 tokenization_kwargs=tokenization_kwargs,

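The key behavioral change in both hunks: the branch is now taken on whether the model is multimodal, not on whether multimodal data happens to be present, and text-only models reject multimodal data explicitly. A condensed paraphrase of that control flow (hedged sketch, not the actual vLLM code):

# Hedged paraphrase of the dispatch above; dispatch_preprocessing is an
# illustrative stand-in for InputPreprocessor._process_text/_process_tokens.
from typing import Any

def dispatch_preprocessing(is_multimodal_model: bool,
                           parsed_content: dict[str, Any]) -> str:
    if is_multimodal_model:
        # Multimodal models always take the multimodal path, even when
        # "multi_modal_data" is missing or empty, so HF-processor side
        # effects (special tokens, prompt updates) are applied consistently.
        return "multimodal path"
    if parsed_content.get("multi_modal_data"):
        # Text-only models now fail loudly instead of silently routing
        # multimodal data through the text path.
        raise ValueError("This model does not support multimodal inputs")
    return "text-only path"

print(dispatch_preprocessing(True, {"prompt": "hi"}))   # multimodal path
print(dispatch_preprocessing(False, {"prompt": "hi"}))  # text-only path
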

vllm/model_executor/models/phi3v.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,8 @@ def _apply_prompt_updates(
507507
)
508508

509509
# Keep the behavior in line with HF processor
510-
if token_ids[:2] == tokenizer.encode("<s> <|image|>",
511-
add_special_tokens=False):
510+
if len(mm_prompt_updates) and (token_ids[:2] == tokenizer.encode(
511+
"<s> <|image|>", add_special_tokens=False)):
512512
token_ids = [token_ids[0], *token_ids[2:]]
513513
placeholders = {
514514
modality: [

vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 1 addition & 0 deletions

@@ -331,6 +331,7 @@ def _maybe_apply_prompt_updates(
         """
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
+        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

         use_audio_in_video = False
         if "video" in mm_kwargs:

vllm/multimodal/processing.py

Lines changed: 20 additions & 6 deletions

@@ -1946,18 +1946,15 @@ def _validate_mm_kwargs(
                 "model (usually arising from an inconsistency between "
                 "`_call_hf_processor` and `_get_mm_fields_config`).")

-    def _validate_mm_placeholders(
+    def _validate_mm_updates(
         self,
-        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+        mm_updates: MultiModalPromptUpdates,
         mm_item_counts: Mapping[str, int],
     ) -> None:
         for modality, item_count in mm_item_counts.items():
-            placeholders = mm_placeholders.get(modality, [])
+            placeholders = mm_updates.get(modality, [])

             if len(placeholders) != item_count:
-                # NOTE: If you are a model developer, this can also arise from
-                # an inconsistency between `_call_hf_processor` and
-                # `_get_mm_fields_config` implementations
                 raise RuntimeError(
                     f"Expected there to be {item_count} prompt updates "
                     f"corresponding to {item_count} {modality} items, but "
@@ -1967,6 +1964,22 @@ def _validate_mm_placeholders(
                     "in the prompt. If the model has a chat template, make "
                     "sure you have applied it before calling `LLM.generate`.")

+    def _validate_mm_placeholders(
+        self,
+        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+        mm_item_counts: Mapping[str, int],
+    ) -> None:
+        for modality, item_count in mm_item_counts.items():
+            placeholders = mm_placeholders.get(modality, [])
+
+            if len(placeholders) != item_count:
+                raise RuntimeError(
+                    f"Expected there to be {item_count} prompt placeholders "
+                    f"corresponding to {item_count} {modality} items, but "
+                    f"instead found {len(placeholders)} prompt placeholders! "
+                    "Make sure the implementation of `_call_hf_processor` and "
+                    "`_get_mm_fields_config` are consistent with each other.")
+
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -1977,6 +1990,7 @@ def _maybe_apply_prompt_updates(
     ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
+        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
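Validation is now split in two: prompt-update counts are checked against the multimodal item counts up front (with a user-facing message about missing placeholders in the prompt), while placeholder counts found after applying updates get a message aimed at processor implementers. A simplified illustration of the shared counting check (hedged sketch, not the vLLM class):

# Illustrative stand-in for the count checks in _validate_mm_updates and
# _validate_mm_placeholders; check_counts is not a vLLM function.
from collections.abc import Mapping, Sequence
from typing import Any

def check_counts(found: Mapping[str, Sequence[Any]],
                 mm_item_counts: Mapping[str, int],
                 what: str) -> None:
    for modality, item_count in mm_item_counts.items():
        n_found = len(found.get(modality, []))
        if n_found != item_count:
            raise RuntimeError(
                f"Expected there to be {item_count} {what} corresponding to "
                f"{item_count} {modality} items, but instead found {n_found}!")

# Two image items but only one resolved prompt update: with the new early
# check, this mismatch is reported before any updates are applied.
try:
    check_counts({"image": ["<image>"]}, {"image": 2}, "prompt updates")
except RuntimeError as exc:
    print(exc)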
