
Commit 4ebc0b9

[Bugfix] Proper input validation for multi-modal encoder-decoder models (#16156)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent dc96fd5 commit 4ebc0b9

File tree

10 files changed: +113 -62 lines changed


examples/offline_inference/encoder_decoder_multimodal.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def run_florence2():
 def run_mllama():
     engine_args = EngineArgs(
         model="meta-llama/Llama-3.2-11B-Vision-Instruct",
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         limit_mm_per_prompt={"image": 1},
         dtype="half",

examples/offline_inference/vision_language.py

Lines changed: 1 addition & 1 deletion
@@ -556,7 +556,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

examples/offline_inference/vision_language_multi_image.py

Lines changed: 2 additions & 2 deletions
@@ -318,8 +318,8 @@ def load_mllama(question: str, image_urls: list[str]) -> ModelRequestData:
     # The configuration below has been confirmed to launch on a single L40 GPU.
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=4096,
-        max_num_seqs=16,
+        max_model_len=8192,
+        max_num_seqs=2,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 

tests/engine/test_short_mm_context.py

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@
 def test_context_length_too_short(vllm_runner, image_assets, model):
     images = [asset.pil_image for asset in image_assets]
 
-    with pytest.raises(ValueError, match="too long to fit into the model"):
+    with pytest.raises(ValueError,
+                       match="longer than the maximum model length"):
         vllm_model = vllm_runner(
             model,
             max_model_len=128,  # LLaVA has a feature size of 576
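
For reference, a rough sketch of what this fixture-based test exercises through the public `LLM` API. The checkpoint name is a placeholder rather than the parametrized `model` used by the test, and the error may surface during engine construction or at the first request:

import pytest
from PIL import Image
from vllm import LLM

# Placeholder checkpoint: LLaVA encodes each image into 576 placeholder tokens
# (per the comment above), so a 128-token context cannot hold the multimodal
# prompt and validation rejects it with the message asserted by the test.
with pytest.raises(ValueError, match="longer than the maximum model length"):
    llm = LLM(model="llava-hf/llava-1.5-7b-hf",  # placeholder name
              max_model_len=128,
              enforce_eager=True)
    llm.generate({
        "prompt": "USER: <image>\nDescribe the image. ASSISTANT:",
        "multi_modal_data": {"image": Image.new("RGB", (336, 336))},
    })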

tests/entrypoints/llm/test_prompt_validation.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def v1(run_with_both_engines):
 
 def test_empty_prompt():
     llm = LLM(model="openai-community/gpt2", enforce_eager=True)
-    with pytest.raises(ValueError, match='Prompt cannot be empty'):
+    with pytest.raises(ValueError, match='decoder prompt cannot be empty'):
        llm.generate([""])
 
 
tests/entrypoints/openai/test_prompt_validation.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ async def test_empty_prompt():
         client = remote_server.get_async_client()
 
         with pytest.raises(openai.BadRequestError,
-                           match=re.compile('.+Prompt cannot be empty.+')):
+                           match="decoder prompt cannot be empty"):
             await client.completions.create(model=model_name,
                                             prompt="",
                                             max_tokens=5,

tests/models/encoder_decoder/vision_language/test_mllama.py

Lines changed: 4 additions & 4 deletions
@@ -211,7 +211,7 @@ def _run_test(
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
-                     max_model_len=4096,
+                     max_model_len=8192,
                      max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
@@ -422,7 +422,7 @@ def test_bnb_regression(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
        max_num_seqs=2,
        quantization="bitsandbytes",
    )
@@ -475,7 +475,7 @@ def test_explicit_implicit_prompt(
     llm = LLM(
         model=model,
         dtype=dtype,
-        max_model_len=4096,
+        max_model_len=8192,
         max_num_seqs=2,
         tensor_parallel_size=1,
     )
@@ -506,7 +506,7 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
     with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
             model,
             dtype=dtype,
-            max_model_len=4096,
+            max_model_len=8192,
             max_num_seqs=2,
             tensor_parallel_size=1,
             limit_mm_per_prompt={"image":

vllm/engine/llm_engine.py

Lines changed: 46 additions & 17 deletions
@@ -8,7 +8,7 @@
 from dataclasses import dataclass
 from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+                    Iterable, List, Literal, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
 
@@ -30,7 +30,7 @@
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType)
+                         PromptType, SingletonInputs)
 from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -40,6 +40,7 @@
     get_local_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
 from vllm.pooling_params import PoolingParams
@@ -2029,29 +2030,57 @@ def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest]):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
 
-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
 
-        prompt_ids = prompt_inputs["prompt_token_ids"]
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
 
-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        if prompt_type == "encoder" and self.tokenizer is not None:
+            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+            model_config = self.model_config
 
-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
 
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper
+
+        prompt_ids = prompt_inputs["prompt_token_ids"]
+
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")
+
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
                     "Make sure that `max_model_len` is no smaller than the "
                     "number of text tokens plus multimodal tokens. For image "
                     "inputs, the number of image tokens depends on the number "
                     "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")
 
         # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
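
To summarize the control flow introduced above, here is a condensed standalone restatement of the new per-prompt check (a sketch that mirrors the diff rather than a drop-in replacement; `is_multimodal` and `skip_encoder_check` stand in for the `model_config.is_multimodal_model` and `pad_dummy_encoder_prompt` lookups):

from typing import Literal

def validate_prompt(prompt_ids: list[int],
                    *,
                    prompt_type: Literal["encoder", "decoder"],
                    max_model_len: int,
                    is_multimodal: bool,
                    skip_encoder_check: bool = False) -> None:
    """Sketch of the checks added in _validate_model_input."""
    # Whisper pads its dummy encoder prompt, so its encoder length is skipped.
    if prompt_type == "encoder" and skip_encoder_check:
        return

    if not prompt_ids:
        raise ValueError(f"The {prompt_type} prompt cannot be empty")

    if len(prompt_ids) >= max_model_len:
        if is_multimodal:
            suggestion = ("Make sure that `max_model_len` is no smaller than "
                          "the number of text tokens plus multimodal tokens.")
        else:
            suggestion = ("Make sure that `max_model_len` is no smaller than "
                          "the number of text tokens.")
        raise ValueError(
            f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
            f"longer than the maximum model length of {max_model_len}. "
            f"{suggestion}")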

vllm/multimodal/profiling.py

Lines changed: 6 additions & 7 deletions
@@ -213,8 +213,12 @@ def get_encoder_dummy_data(
 
         total_len = len(encoder_prompt_token_ids)
 
-        # Encoder-decoder multimodal models only support v0
-        if total_len > seq_len:
+        processor = cast(EncDecMultiModalProcessor, self.processor)
+        if processor.pad_dummy_encoder_prompt:
+            num_tokens_to_pad = max(total_len, seq_len) - total_len
+            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+        # NOTE: Whisper allows total_len > seq_len.
+        elif total_len > seq_len and not envs.VLLM_USE_V1:
             # `max_num_batched_tokens` is defined by `SchedulerConfig`
             logger.warning_once(
                 "The encoder sequence length used for profiling ("
@@ -229,11 +233,6 @@ def get_encoder_dummy_data(
                 "increase `max_model_len`, reduce `max_num_seqs`, "
                 "and/or reduce `mm_counts`.")
 
-        processor = cast(EncDecMultiModalProcessor, self.processor)
-        if processor.pad_dummy_encoder_prompt:
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
-
         return DummyEncoderData(encoder_prompt_token_ids)
 
     def get_decoder_dummy_data(
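
The padding arithmetic above only tops a short dummy encoder prompt up to `seq_len` and never truncates a longer one; a small worked example in plain Python (not vLLM code) makes the `max(total_len, seq_len) - total_len` expression concrete:

# Pads short dummy prompts up to seq_len and leaves longer ones untouched
# (Whisper allows total_len > seq_len).
def num_tokens_to_pad(total_len: int, seq_len: int) -> int:
    return max(total_len, seq_len) - total_len

assert num_tokens_to_pad(total_len=3, seq_len=8) == 5   # pad up to seq_len
assert num_tokens_to_pad(total_len=10, seq_len=8) == 0  # already long enough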

vllm/v1/engine/processor.py

Lines changed: 49 additions & 27 deletions
@@ -2,16 +2,17 @@
 
 import time
 from collections.abc import Mapping
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 from vllm.config import VllmConfig
-from vllm.inputs import ProcessorInputs, PromptType
+from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                              MultiModalRegistry)
 from vllm.multimodal.inputs import PlaceholderRange
+from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -287,41 +288,62 @@ def _validate_model_inputs(self,
                                lora_request: Optional[LoRARequest] = None):
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
 
-        # For encoder-decoder multimodal models, the max_prompt_len
-        # restricts the decoder prompt length
-        if self.model_config.is_multimodal_model:
-            prompt_inputs = decoder_inputs
-        else:
-            prompt_inputs = encoder_inputs or decoder_inputs
+        if encoder_inputs is not None:
+            self._validate_model_input(encoder_inputs,
+                                       lora_request,
+                                       prompt_type="encoder")
 
-        prompt_ids = prompt_inputs["prompt_token_ids"]
+        self._validate_model_input(decoder_inputs,
+                                   lora_request,
+                                   prompt_type="decoder")
 
-        if prompt_ids is None or len(prompt_ids) == 0:
-            raise ValueError("Prompt cannot be empty")
+    def _validate_model_input(
+        self,
+        prompt_inputs: SingletonInputs,
+        lora_request: Optional[LoRARequest],
+        *,
+        prompt_type: Literal["encoder", "decoder"],
+    ):
+        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
 
-        max_input_id = max(prompt_ids)
-        max_allowed = self.tokenizer.get_lora_tokenizer(
-            lora_request).max_token_id
-        if max_input_id > max_allowed:
-            raise ValueError(
-                "Token id {} is out of vocabulary".format(max_input_id))
+        if prompt_type == "encoder":
+            model_config = self.model_config
 
-        if len(prompt_ids) >= self.model_config.max_model_len:
-            raise ValueError(
-                f"Prompt length of {len(prompt_ids)} is longer than the "
-                f"maximum model length of {self.model_config.max_model_len}.")
+            if model_config.is_multimodal_model:
+                mm_registry = self.input_preprocessor.mm_registry
+                mm_processor = mm_registry.create_processor(
+                    model_config, tokenizer=tokenizer)
+                assert isinstance(mm_processor, EncDecMultiModalProcessor)
 
-        if self.model_config.is_multimodal_model:
-            max_prompt_len = self.model_config.max_model_len
+                if mm_processor.pad_dummy_encoder_prompt:
+                    return  # Skip encoder length check for Whisper
 
-            if len(prompt_ids) > max_prompt_len:
-                raise ValueError(
-                    f"The prompt (total length {len(prompt_ids)}) is too long "
-                    f"to fit into the model (context length {max_prompt_len}). "
+        prompt_ids = prompt_inputs["prompt_token_ids"]
+
+        if not prompt_ids:
+            raise ValueError(f"The {prompt_type} prompt cannot be empty")
+
+        max_input_id = max(prompt_ids)
+        if max_input_id > tokenizer.max_token_id:
+            raise ValueError(f"Token id {max_input_id} is out of vocabulary")
+
+        max_prompt_len = self.model_config.max_model_len
+        if len(prompt_ids) >= max_prompt_len:
+            if self.model_config.is_multimodal_model:
+                suggestion = (
                     "Make sure that `max_model_len` is no smaller than the "
                     "number of text tokens plus multimodal tokens. For image "
                     "inputs, the number of image tokens depends on the number "
                     "of images, and possibly their aspect ratios as well.")
+            else:
+                suggestion = (
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens.")
+
+            raise ValueError(
+                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"longer than the maximum model length of {max_prompt_len}. "
+                f"{suggestion}")
 
         # TODO: Find out how many placeholder tokens are there so we can
        # check that chunked prefill does not truncate them
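
The V1 processor additionally keeps the out-of-vocabulary check and the strict `>=` length comparison inside the new per-prompt helper, so a prompt whose token count equals `max_model_len` is rejected just like a longer one. A toy restatement of those two checks in isolation (a sketch, not vLLM code; `max_token_id` stands in for the tokenizer attribute referenced in the diff):

def check_prompt_ids(prompt_ids: list[int],
                     max_token_id: int,
                     max_model_len: int) -> None:
    # Reject ids the tokenizer does not know about.
    max_input_id = max(prompt_ids)
    if max_input_id > max_token_id:
        raise ValueError(f"Token id {max_input_id} is out of vocabulary")
    # A prompt of exactly max_model_len tokens is also rejected (strict >=),
    # leaving room for at least one generated token.
    if len(prompt_ids) >= max_model_len:
        raise ValueError("prompt is longer than the maximum model length")

check_prompt_ids([1, 5, 9], max_token_id=10, max_model_len=4)        # passes
# check_prompt_ids([1, 5, 11], max_token_id=10, max_model_len=4)     # out of vocabulary
# check_prompt_ids([1, 2, 3, 4], max_token_id=10, max_model_len=4)   # too long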
