
Commit 166b551

[Core] Remove legacy input mapper/processor from V0
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 2d9045f commit 166b551

21 files changed (+156, -1351 lines)

vllm/core/scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -1596,7 +1596,6 @@ def schedule(
                     multi_modal_placeholders=(
                         seq_group.multi_modal_placeholders
                         if scheduler_outputs.num_prefill_groups > 0 else None),
-                    mm_processor_kwargs=seq_group.mm_processor_kwargs,
                     prompt_adapter_request=seq_group.prompt_adapter_request,
                 )
             else:

vllm/engine/async_llm_engine.py

Lines changed: 1 addition & 2 deletions
@@ -490,12 +490,11 @@ async def add_request_async(
         tokenizer = await self.get_tokenizer_async(lora_request)
         self._validate_token_prompt(prompt, tokenizer=tokenizer)

-        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
+        processed_inputs = await self.input_preprocessor.preprocess_async(
             prompt,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
-        processed_inputs = self.input_processor(preprocessed_inputs)

         if isinstance(params, SamplingParams) and \
                 params.guided_decoding is not None:
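
The effect of this hunk is that add_request_async no longer chains a separate legacy input processor after preprocessing. A minimal sketch of the before/after call flow, assuming only the attributes visible in the diff (the real method takes more arguments than shown here):

# Hedged sketch, not the actual vLLM implementation: `engine` stands in for
# AsyncLLMEngine, and only attributes named in the diff are assumed.

async def add_request_legacy(engine, prompt, lora_request=None):
    # V0 legacy pipeline: preprocess, then run the per-model input processor
    # that was registered through the input registry.
    preprocessed_inputs = await engine.input_preprocessor.preprocess_async(
        prompt, lora_request=lora_request)
    return engine.input_processor(preprocessed_inputs)


async def add_request_now(engine, prompt, lora_request=None):
    # After this commit the preprocessor output is already the processed
    # input, so the extra hop through engine.input_processor disappears.
    return await engine.input_preprocessor.preprocess_async(
        prompt, lora_request=lora_request)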

vllm/engine/llm_engine.py

Lines changed: 3 additions & 10 deletions
@@ -29,8 +29,7 @@
 from vllm.entrypoints.openai.logits_processors import (
     get_logits_processors as get_openai_logits_processors)
 from vllm.executor.executor_base import ExecutorBase
-from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
-                         PromptType)
+from vllm.inputs import ProcessorInputs, PromptType
 from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -212,7 +211,6 @@ def __init__(
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
-        input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         use_cached_outputs: bool = False,
     ) -> None:
@@ -273,11 +271,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             self.tokenizer,
             mm_registry)

-        self.input_registry = input_registry
-        self.input_processor = input_registry.create_input_processor(
-            self.model_config)
-
-        self.model_executor = executor_class(vllm_config=vllm_config, )
+        self.model_executor = executor_class(vllm_config=vllm_config)

         if self.model_config.runner_type != "pooling":
             self._initialize_kv_caches()
@@ -776,12 +770,11 @@ def add_request(
             prompt,
             tokenizer=self.get_tokenizer(lora_request=lora_request))

-        preprocessed_inputs = self.input_preprocessor.preprocess(
+        processed_inputs = self.input_preprocessor.preprocess(
             prompt,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
-        processed_inputs = self.input_processor(preprocessed_inputs)

         self._add_processed_request(
             request_id=request_id,
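
With the legacy processor gone, LLMEngine no longer accepts an input_registry argument or builds self.input_processor. A hedged caller-side sketch of what that means for code constructing the engine directly (the arguments shown are only the ones visible in this diff, not the full signature):

# Hedged sketch: vllm_config and executor_class are placeholders the caller
# must supply; only parameters that appear in the hunks above are used.

from vllm.engine.llm_engine import LLMEngine


def build_engine(vllm_config, executor_class):
    # Before this commit, callers could inject a custom registry:
    #   LLMEngine(vllm_config=vllm_config, executor_class=executor_class,
    #             log_stats=False, input_registry=INPUT_REGISTRY)
    # After it, the parameter no longer exists and passing it raises
    # TypeError; the multi-modal registry (mm_registry) remains.
    return LLMEngine(vllm_config=vllm_config,
                     executor_class=executor_class,
                     log_stats=False)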

vllm/inputs/__init__.py

Lines changed: 6 additions & 10 deletions
@@ -2,12 +2,10 @@

 from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
                    ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
-                   SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
-                   TextPrompt, TokenInputs, TokensPrompt,
-                   build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
-                   token_inputs, zip_enc_dec_prompts)
-from .registry import (DummyData, InputContext, InputProcessingContext,
-                       InputRegistry)
+                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
+                   TokensPrompt, build_explicit_enc_dec_prompt,
+                   to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
+from .registry import InputContext, InputProcessingContext, InputRegistry

 INPUT_REGISTRY = InputRegistry()
 """
@@ -27,13 +25,11 @@
     "EncoderDecoderInputs",
     "ProcessorInputs",
     "SingletonInputs",
-    "SingletonInputsAdapter",
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
-    "INPUT_REGISTRY",
-    "DummyData",
+    "INPUT_REGISTRY",  # DEPRECATED
     "InputContext",
     "InputProcessingContext",
-    "InputRegistry",
+    "InputRegistry",  # DEPRECATED
 ]
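
For downstream code, the visible effect is a smaller vllm.inputs surface. A sketch of which imports keep working, based solely on the __all__ edits above:

# Sketch of the import surface after this commit, derived from the hunks
# above; no other exports are assumed.

# Still exported (INPUT_REGISTRY and InputRegistry remain, but are now
# marked DEPRECATED in __all__):
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, TextPrompt,
                         TokensPrompt, token_inputs)

# Removed from the package; either line now fails with ImportError:
# from vllm.inputs import SingletonInputsAdapter
# from vllm.inputs import DummyData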

vllm/inputs/data.py

Lines changed: 2 additions & 139 deletions
@@ -1,17 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
-
 from collections.abc import Iterable
-from dataclasses import dataclass
-from functools import cached_property
 from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

-import torch
-from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
+from typing_extensions import NotRequired, TypedDict, TypeVar

 if TYPE_CHECKING:
-    from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
-                                 MultiModalPlaceholderDict)
-    from vllm.multimodal.inputs import MultiModalInputs
+    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs


 class TextPrompt(TypedDict):
@@ -153,22 +147,6 @@ class TokenInputs(TypedDict):
     if the model supports it.
     """

-    multi_modal_inputs: NotRequired["MultiModalKwargs"]
-    """
-    Optional multi-modal inputs to pass to the model,
-    if the model supports it.
-    """
-
-    multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
-    """
-    Placeholder ranges for the multi-modal data.
-    """
-
-    multi_modal_hashes: NotRequired[list[str]]
-    """
-    The hashes of the multi-modal data.
-    """
-
     mm_processor_kwargs: NotRequired[dict[str, Any]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
@@ -183,9 +161,6 @@ def token_inputs(
     token_type_ids: Optional[list[int]] = None,
     prompt: Optional[str] = None,
     multi_modal_data: Optional["MultiModalDataDict"] = None,
-    multi_modal_inputs: Optional["MultiModalKwargs"] = None,
-    multi_modal_hashes: Optional[list[str]] = None,
-    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
     mm_processor_kwargs: Optional[dict[str, Any]] = None,
 ) -> TokenInputs:
     """Construct :class:`TokenInputs` from optional values."""
@@ -197,12 +172,6 @@ def token_inputs(
         inputs["token_type_ids"] = token_type_ids
     if multi_modal_data is not None:
         inputs["multi_modal_data"] = multi_modal_data
-    if multi_modal_inputs is not None:
-        inputs["multi_modal_inputs"] = multi_modal_inputs
-    if multi_modal_hashes is not None:
-        inputs["multi_modal_hashes"] = multi_modal_hashes
-    if multi_modal_placeholders is not None:
-        inputs["multi_modal_placeholders"] = multi_modal_placeholders
     if mm_processor_kwargs is not None:
         inputs["mm_processor_kwargs"] = mm_processor_kwargs

@@ -237,112 +206,6 @@ class EncoderDecoderInputs(TypedDict):
     :class:`vllm.sequence.Sequence`.
     """

-
-@dataclass
-class SingletonInputsAdapter:
-    """
-    Unified interface to access the components of :class:`SingletonInputs`.
-    """
-    inputs: SingletonInputs
-
-    @cached_property
-    def prompt(self) -> Optional[str]:
-        inputs = self.inputs
-
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
-            return inputs.get("prompt")
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def prompt_token_ids(self) -> list[int]:
-        inputs = self.inputs
-
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
-            return inputs.get("prompt_token_ids", [])
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def token_type_ids(self) -> list[int]:
-        inputs = self.inputs
-
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
-            return inputs.get("token_type_ids", [])
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def prompt_embeds(self) -> Optional[torch.Tensor]:
-        inputs = self.inputs
-
-        if inputs["type"] == "token" or inputs["type"] == "multimodal":
-            return None
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def multi_modal_data(self) -> "MultiModalDataDict":
-        inputs = self.inputs
-
-        if inputs["type"] == "token":
-            return inputs.get("multi_modal_data", {})
-
-        if inputs["type"] == "multimodal":
-            return inputs.get("mm_kwargs", {})
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]:
-        inputs = self.inputs
-
-        if inputs["type"] == "token":
-            return inputs.get("multi_modal_inputs", {})
-
-        if inputs["type"] == "multimodal":
-            return inputs.get("mm_kwargs", {})
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def multi_modal_hashes(self) -> list[str]:
-        inputs = self.inputs
-
-        if inputs["type"] == "token":
-            return inputs.get("multi_modal_hashes", [])
-
-        if inputs["type"] == "multimodal":
-            # only the case when we use MultiModalInputs
-            return inputs.get("mm_hashes", [])  # type: ignore[return-value]
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
-        inputs = self.inputs
-
-        if inputs["type"] == "token":
-            return inputs.get("multi_modal_placeholders", {})
-
-        if inputs["type"] == "multimodal":
-            return inputs.get("mm_placeholders", {})
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-    @cached_property
-    def mm_processor_kwargs(self) -> dict[str, Any]:
-        inputs = self.inputs
-
-        if inputs["type"] == "token":
-            return inputs.get("mm_processor_kwargs", {})
-
-        if inputs["type"] == "multimodal":
-            return {}
-
-        assert_never(inputs)  # type: ignore[arg-type]
-
-
 ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
 """
 The inputs to :data:`vllm.inputs.InputProcessor`.
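
TokenInputs and the token_inputs() helper lose their pre-computed multi-modal fields; that data is now produced by the multi-modal processor instead of being attached to token inputs. A hedged usage sketch with only the parameters that survive this commit (literal values are made up):

# Hedged sketch of the slimmed-down helper; only keywords visible in the
# post-commit signature above are used.

from vllm.inputs import token_inputs

inputs = token_inputs(
    prompt_token_ids=[101, 2009, 2003, 102],
    prompt="it is",
    # multi_modal_data and mm_processor_kwargs are still accepted; the
    # multi_modal_inputs / multi_modal_hashes / multi_modal_placeholders
    # keywords no longer exist, so passing them now raises TypeError.
)
assert inputs["type"] == "token"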

vllm/inputs/preprocess.py

Lines changed: 10 additions & 38 deletions
@@ -223,28 +223,6 @@ async def _tokenize_prompt_async(
             lora_request=lora_request,
             add_special_tokens=add_special_tokens)

-    def _can_process_multimodal(self) -> bool:
-        model_config = self.model_config
-
-        if not model_config.is_multimodal_model:
-            raise ValueError("Your model does not support multi-modal inputs")
-
-        # Interim measure so we can handle models that have yet to be
-        # updated to use the new multi-modal processor
-        can_process_multimodal = self.mm_registry.has_processor(model_config)
-        if not can_process_multimodal:
-            from vllm.model_executor.models.registry import _VLLM_MODELS
-            if not any(arch in _VLLM_MODELS
-                       for arch in model_config.architectures):
-                logger.warning_once(
-                    "Your model uses the legacy input pipeline, which will be "
-                    "removed in an upcoming release. "
-                    "Please upgrade to the new multi-modal processing pipeline "
-                    "(https://docs.vllm.ai/en/latest/design/mm_processing.html)"
-                )
-
-        return can_process_multimodal
-
     def _process_multimodal(
         self,
         prompt: Union[str, list[int]],
@@ -258,8 +236,7 @@ def _process_multimodal(
         returning the corresponding token IDs and metadata.
         """
         # At the moment on model (PrithviGeoSpatialMAE) requires to be
-        # initialized without a tokenizer while using also multi-modal
-        # input.
+        # initialized without a tokenizer while using also multi-modal input
         if not self.tokenizer:
             tokenizer = object()  # Dummy
         else:
@@ -285,8 +262,7 @@ async def _process_multimodal_async(
     ) -> MultiModalInputs:
         """Async version of :meth:`_process_multimodal`."""
         # At the moment on model (PrithviGeoSpatialMAE) requires to be
-        # initialized without a tokenizer while using also multi-modal
-        # input.
+        # initialized without a tokenizer while using also multi-modal input
         if not self.tokenizer:
             tokenizer = object()  # Dummy
         else:
@@ -343,7 +319,7 @@ def _prompt_to_llm_inputs(
         multi_modal_data = tokens_content.get("multi_modal_data")
         mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")

-        if multi_modal_data is not None and self._can_process_multimodal():
+        if multi_modal_data is not None:
             return self._process_multimodal(
                 prompt_token_ids,
                 multi_modal_data,
@@ -366,7 +342,7 @@ def _prompt_to_llm_inputs(
         multi_modal_data = text_content.get("multi_modal_data")
         mm_processor_kwargs = text_content.get("mm_processor_kwargs")

-        if multi_modal_data is not None and self._can_process_multimodal():
+        if multi_modal_data is not None:
             return self._process_multimodal(
                 prompt_text,
                 multi_modal_data,
@@ -417,7 +393,7 @@ async def _prompt_to_llm_inputs_async(
         multi_modal_data = tokens_content.get("multi_modal_data")
         mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")

-        if multi_modal_data is not None and self._can_process_multimodal():
+        if multi_modal_data is not None:
             return await self._process_multimodal_async(
                 prompt_token_ids,
                 multi_modal_data,
@@ -439,7 +415,7 @@ async def _prompt_to_llm_inputs_async(
         multi_modal_data = text_content.get("multi_modal_data")
         mm_processor_kwargs = text_content.get("mm_processor_kwargs")

-        if multi_modal_data is not None and self._can_process_multimodal():
+        if multi_modal_data is not None:
             return await self._process_multimodal_async(
                 prompt_text,
                 multi_modal_data,
@@ -594,15 +570,13 @@ def _process_encoder_decoder_prompt(
             decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
             # For multimodal model, override decoder prompt from processor
             # with explicit decoder prompt.
-            if self.model_config.is_multimodal_model and (
-                    self._can_process_multimodal()):
+            if self.model_config.is_multimodal_model:
                 encoder_inputs, decoder_inputs = (
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
             inputs = self._prompt_to_llm_inputs(prompt)
-            if self.model_config.is_multimodal_model and (
-                    self._can_process_multimodal()):
+            if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
@@ -637,15 +611,13 @@ async def _process_encoder_decoder_prompt_async(

             # For multimodal model, override decoder prompt from processor
             # with explicit decoder prompt.
-            if self.model_config.is_multimodal_model and (
-                    self._can_process_multimodal()):
+            if self.model_config.is_multimodal_model:
                 encoder_inputs, decoder_inputs = (
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
             inputs = await self._prompt_to_llm_inputs_async(prompt)
-            if self.model_config.is_multimodal_model and (
-                    self._can_process_multimodal()):
+            if self.model_config.is_multimodal_model:
                 # Encoder-Decoder Multimodal model
                 encoder_inputs, decoder_inputs = (
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
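
With _can_process_multimodal() deleted, prompt routing no longer falls back to the legacy input mapper. A hedged sketch of the resulting branch, mirroring the hunks above (preprocessor and parsed_content are stand-in names, not vLLM identifiers, and the positional argument order follows what the diff shows):

# Hedged sketch of the simplified routing inside InputPreprocessor.
# `preprocessor` stands in for an InputPreprocessor instance and
# `parsed_content` for the parsed prompt dict.

def route_token_prompt(preprocessor, prompt_token_ids, parsed_content):
    multi_modal_data = parsed_content.get("multi_modal_data")
    mm_processor_kwargs = parsed_content.get("mm_processor_kwargs")

    # Previously this branch also required
    # preprocessor._can_process_multimodal() to return True, with a warning
    # and a legacy-pipeline fallback otherwise. Now every prompt carrying
    # multi-modal data goes straight to the multi-modal processor.
    if multi_modal_data is not None:
        return preprocessor._process_multimodal(
            prompt_token_ids,
            multi_modal_data,
            mm_processor_kwargs,
        )

    # Text-only prompts still become plain token inputs.
    from vllm.inputs import token_inputs
    return token_inputs(prompt_token_ids=prompt_token_ids)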
