
Commit 2471815

[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 07bf813 commit 2471815

File tree: 8 files changed (+47, -52 lines)


tests/models/multimodal/processing/test_idefics3.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ def test_processor_override(
     num_imgs: int,
     kwargs_on_init: bool,
 ):
-    """Ensure input_processor_for_idefics3 handles num_crops properly."""
+    """Ensure Idefics3MultiModalProcessor handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
     # the partial when calling the custom input processor.

tests/models/multimodal/processing/test_phi3v.py
Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ def test_processor_override(
     num_imgs: int,
     kwargs_on_init: bool,
 ):
-    """Ensure input_processor_for_phi3v handles num_crops properly."""
+    """Ensure Phi3VMultiModalProcessor handles num_crops properly."""
     # Avoid initializing CUDA early
     from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID

vllm/engine/arg_utils.py
Lines changed: 1 addition & 1 deletion

@@ -665,7 +665,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             type=nullable_kvs,
             default=EngineArgs.limit_mm_per_prompt,
             # The default value is given in
-            # MultiModalRegistry.init_mm_limits_per_prompt
+            # MultiModalConfig.get_limit_per_prompt
             help=('For each multimodal plugin, limit how many '
                   'input instances to allow for each prompt. '
                   'Expects a comma-separated list of items, '

vllm/engine/llm_engine.py
Lines changed: 11 additions & 15 deletions

@@ -30,8 +30,8 @@
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType, SingletonInputsAdapter)
-from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
+                         PromptType)
+from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.logits_process import get_bad_words_logits_processors

@@ -609,12 +609,7 @@ def _add_processed_request(
         seq_id = next(self.seq_counter)
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

-        if is_encoder_decoder_inputs(processed_inputs):
-            decoder_inputs = processed_inputs["decoder"]
-            encoder_inputs = processed_inputs["encoder"]
-        else:
-            decoder_inputs = processed_inputs
-            encoder_inputs = None
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

         seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                        lora_request, prompt_adapter_request)

@@ -2031,15 +2026,16 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
     def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest]):
-        if is_encoder_decoder_inputs(inputs):
-            # For encoder-decoder multimodal models, the max_prompt_len
-            # restricts the decoder prompt length
-            prompt_inputs = inputs["decoder" if self.model_config.
-                                   is_multimodal_model else "encoder"]
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
+
+        # For encoder-decoder multimodal models, the max_prompt_len
+        # restricts the decoder prompt length
+        if self.model_config.is_multimodal_model:
+            prompt_inputs = decoder_inputs
         else:
-            prompt_inputs = inputs
+            prompt_inputs = encoder_inputs or decoder_inputs

-        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
+        prompt_ids = prompt_inputs["prompt_token_ids"]

         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
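The `encoder_inputs or decoder_inputs` fallback works because split_enc_dec_inputs returns None for the encoder half of decoder-only inputs, so a single expression serves both model families. Below is a minimal standalone sketch of the same selection logic; the pick_prompt_inputs helper and the is_multimodal flag are hypothetical stand-ins for the method and self.model_config.is_multimodal_model:

    from vllm.inputs.parse import split_enc_dec_inputs

    def pick_prompt_inputs(inputs, is_multimodal: bool):
        # Mirrors the selection in the rewritten _validate_model_inputs.
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
        if is_multimodal:
            # For encoder-decoder multimodal models, max_prompt_len
            # restricts the decoder prompt length.
            return decoder_inputs
        # Encoder-decoder text models validate the encoder prompt;
        # decoder-only models get encoder_inputs=None and fall back.
        return encoder_inputs or decoder_inputs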

vllm/inputs/parse.py
Lines changed: 14 additions & 8 deletions

@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
-
 from collections.abc import Sequence
-from typing import Literal, TypedDict, Union, cast, overload
+from typing import Literal, Optional, TypedDict, Union, cast, overload

 from typing_extensions import TypeIs

 from vllm.utils import is_list_of

-from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
-                   ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt)
+from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
+                   SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)


 class ParsedText(TypedDict):

@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
     return isinstance(prompt, dict) and "encoder_prompt" in prompt


-def is_encoder_decoder_inputs(
-        inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
-    return "encoder" in inputs and "decoder" in inputs
+def split_enc_dec_inputs(
+    inputs: ProcessorInputs,
+) -> tuple[Optional[SingletonInputs], SingletonInputs]:
+    if "encoder" in inputs and "decoder" in inputs:
+        # NOTE: This passes pyright but not mypy
+        return (
+            inputs["encoder"],  # type: ignore[typeddict-item]
+            inputs["decoder"],  # type: ignore[typeddict-item]
+        )
+
+    return None, inputs
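For context, the new helper collapses the old two-branch check into a single tuple unpacking at every call site. A quick usage sketch; the plain dicts below are illustrative stand-ins for the ProcessorInputs and SingletonInputs TypedDicts, so a type checker would flag them even though they behave the same at runtime:

    from vllm.inputs.parse import split_enc_dec_inputs

    # Decoder-only inputs: no encoder half; the whole dict is the decoder side.
    enc, dec = split_enc_dec_inputs({"prompt_token_ids": [1, 2, 3]})
    assert enc is None and dec["prompt_token_ids"] == [1, 2, 3]

    # Encoder-decoder inputs: the two halves come back separately.
    enc, dec = split_enc_dec_inputs({
        "encoder": {"prompt_token_ids": [4, 5]},
        "decoder": {"prompt_token_ids": [1]},
    })
    assert enc is not None and enc["prompt_token_ids"] == [4, 5]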

vllm/inputs/registry.py
Lines changed: 6 additions & 8 deletions

@@ -19,7 +19,7 @@
     resolve_mm_processor_kwargs)

 from .data import ProcessorInputs, SingletonInputs
-from .parse import is_encoder_decoder_inputs
+from .parse import split_enc_dec_inputs

 if TYPE_CHECKING:
     from vllm.config import ModelConfig

@@ -462,13 +462,11 @@ def process_input(self, model_config: "ModelConfig",
             **mm_processor_kwargs,
         )

-        if is_encoder_decoder_inputs(processed_inputs):
-            self._ensure_mm_kwargs(processed_inputs["encoder"],
-                                   mm_processor_kwargs)
-            self._ensure_mm_kwargs(processed_inputs["decoder"],
-                                   mm_processor_kwargs)
-        else:
-            self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        if encoder_inputs is not None:
+            self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
+        if decoder_inputs is not None:
+            self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)

         return processed_inputs

vllm/model_executor/models/idefics3.py
Lines changed: 2 additions & 2 deletions

@@ -232,7 +232,7 @@ def get_dummy_processor_inputs(
         )


-class Idefics3MultimodalProcessor(
+class Idefics3MultiModalProcessor(
         BaseMultiModalProcessor[Idefics3ProcessingInfo]):

     def _call_hf_processor(

@@ -575,7 +575,7 @@ def forward(


 @MULTIMODAL_REGISTRY.register_processor(
-    Idefics3MultimodalProcessor,
+    Idefics3MultiModalProcessor,
     info=Idefics3ProcessingInfo,
     dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,

vllm/v1/engine/processor.py
Lines changed: 11 additions & 16 deletions

@@ -7,7 +7,7 @@
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
-from vllm.inputs.parse import is_encoder_decoder_inputs
+from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,

@@ -209,14 +209,8 @@ def process_inputs(
         self._validate_model_inputs(processed_inputs, lora_request)

-        if is_encoder_decoder_inputs(processed_inputs):
-            decoder_inputs = SingletonInputsAdapter(
-                processed_inputs["decoder"])
-            encoder_inputs = SingletonInputsAdapter(
-                processed_inputs["encoder"])
-        else:
-            decoder_inputs = SingletonInputsAdapter(processed_inputs)
-            encoder_inputs = None
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        decoder_inputs = SingletonInputsAdapter(decoder_inputs)

         # TODO: Impl encoder-decoder
         if encoder_inputs is not None:

@@ -301,15 +295,16 @@ def process_inputs(
     def _validate_model_inputs(self,
                                inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest] = None):
-        if is_encoder_decoder_inputs(inputs):
-            # For encoder-decoder multimodal models, the max_prompt_len
-            # restricts the decoder prompt length
-            prompt_inputs = inputs["decoder" if self.model_config.
-                                   is_multimodal_model else "encoder"]
+        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
+
+        # For encoder-decoder multimodal models, the max_prompt_len
+        # restricts the decoder prompt length
+        if self.model_config.is_multimodal_model:
+            prompt_inputs = decoder_inputs
         else:
-            prompt_inputs = inputs
+            prompt_inputs = encoder_inputs or decoder_inputs

-        prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
+        prompt_ids = prompt_inputs["prompt_token_ids"]

         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
