Commit b539222

[V1] Remove input cache client (#14864)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 8d6cf89 commit b539222

File tree

5 files changed: +48 additions, -201 deletions

vllm/inputs/preprocess.py

Lines changed: 6 additions & 0 deletions
@@ -379,6 +379,7 @@ def _prompt_to_llm_inputs(
                 multi_modal_data,
                 mm_processor_kwargs,
                 lora_request=lora_request,
+                return_mm_hashes=return_mm_hashes,
             )

         prompt_token_ids = self._tokenize_prompt(
@@ -401,6 +402,7 @@ async def _prompt_to_llm_inputs_async(
         prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        return_mm_hashes: bool = False,
     ) -> SingletonInputs:
         """Async version of :meth:`_extract_prompt_components`."""
         parsed = parse_singleton_prompt(prompt)
@@ -431,6 +433,7 @@ async def _prompt_to_llm_inputs_async(
                 multi_modal_data,
                 mm_processor_kwargs,
                 lora_request=lora_request,
+                return_mm_hashes=return_mm_hashes,
             )

         return token_inputs(
@@ -452,6 +455,7 @@ async def _prompt_to_llm_inputs_async(
                 multi_modal_data,
                 mm_processor_kwargs,
                 lora_request=lora_request,
+                return_mm_hashes=return_mm_hashes,
             )

         prompt_token_ids = await self._tokenize_prompt_async(
@@ -726,6 +730,7 @@ def _process_decoder_only_prompt(
             prompt,
             request_id=request_id,
             lora_request=lora_request,
+            return_mm_hashes=return_mm_hashes,
         )

         return self._build_decoder_only_llm_inputs(
@@ -746,6 +751,7 @@ async def _process_decoder_only_prompt_async(
             prompt,
             request_id=request_id,
             lora_request=lora_request,
+            return_mm_hashes=return_mm_hashes,
         )

         return self._build_decoder_only_llm_inputs(
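For reference, the new return_mm_hashes keyword is simply threaded from the prompt-preprocessing entry points down to the multimodal processing step, so the hashes computed by the merged processor can be surfaced alongside the processed kwargs. A minimal sketch of that pattern, using hypothetical helper names (process_multimodal, prompt_to_llm_inputs) rather than the real vLLM signatures:

# Hypothetical, simplified illustration of how a "return hashes or not" flag
# is threaded from the prompt-preprocessing entry point down to the
# multimodal processing step (names below are not the real vLLM API).
def process_multimodal(mm_data: dict, *, return_mm_hashes: bool = False):
    kwargs = {"pixel_values": f"<tensors for {list(mm_data)}>"}
    hashes = [f"hash-of-{key}" for key in mm_data] if return_mm_hashes else None
    return kwargs, hashes


def prompt_to_llm_inputs(prompt: dict, *, return_mm_hashes: bool = False):
    mm_data = prompt.get("multi_modal_data") or {}
    # The flag is forwarded unchanged, mirroring what the diff above does in
    # InputPreprocessor._prompt_to_llm_inputs and its async variant.
    return process_multimodal(mm_data, return_mm_hashes=return_mm_hashes)


print(prompt_to_llm_inputs({"multi_modal_data": {"image": "img0"}},
                           return_mm_hashes=True))

This matches the processor.py change further below, which now reads hashes from decoder_inputs.multi_modal_hashes instead of re-hashing the prompt with MultiModalHasher.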

vllm/v1/engine/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ class EngineCoreRequest(
     # Detokenizer, but set to None when it is added to EngineCoreClient.
     prompt: Optional[str]
     prompt_token_ids: list[int]
-    mm_inputs: Optional[list[Optional[MultiModalKwargs]]]
+    mm_inputs: Optional[list[MultiModalKwargs]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: SamplingParams

vllm/v1/engine/mm_input_cache.py

Lines changed: 10 additions & 112 deletions
@@ -1,131 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Optional
-
-from vllm.config import ModelConfig
 from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
-from vllm.logger import init_logger
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
-                             MultiModalKwargs, MultiModalRegistry)
+from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.processing import ProcessingCache

-logger = init_logger(__name__)
-
 # The idea of multimodal preprocessing caching is based on having a client and
 # a server, where the client executes in the frontend process (=P0) and the
 # server in the core process (=P1).
 #
 # -- Client:
-# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
-# - Perform caching of the generated MultiModalKwargs.
-# - This client can be deprecated once all mutimodal models migrate to use
-#   merged preprocessor with built-in caching functionality.
+# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
+#   with built-in caching functionality, with mm_hash as its identifier.
 #
 # -- Server:
-# - Perform caching of the received MultiModalKwargs.
+# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
 #
-# The caching for both client and server is mirrored/similar, and this allows us
+# The caching for both client and server is mirrored, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
-# client (=P0) and server (=P1) processes.
+# client (=P0) and server (=P1) processes if the mm_hash is found in the client
+# cache.

 # Both Client and Server must use the same cache size
 # (to perform mirrored caching). This cache size is set by the environment
 # variable VLLM_MM_INPUT_CACHE_GIB.


-# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
-# merged preprocessor with built-in caching functionality.
-class MMInputCacheClient:
-
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-    ):
-        self.model_config = model_config
-        self.mm_registry = mm_registry
-        self.multi_modal_input_mapper = mm_registry.create_input_mapper(
-            model_config)
-        self.mm_registry.init_mm_limits_per_prompt(model_config)
-
-        # Init cache
-        self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
-                                                      MultiModalKwargs)
-
-        # DEBUG: Set to None to disable
-        self.mm_debug_cache_hit_ratio_steps = None
-        self.mm_debug_cache_hits = 0
-        self.mm_debug_cache_total = 0
-
-    def cache_hit_ratio(self, steps):
-        total = self.mm_debug_cache_total
-
-        if total > 0 and total % steps == 0:
-            logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
-                         self.mm_debug_cache_hits / total)
-
-    # NOTE: process_inputs only supports image inputs since all multimodal
-    # models with other modalities have migrated to use merged preprocessor.
-    def process_inputs(
-        self,
-        mm_data: MultiModalDataDict,
-        mm_hashes: Optional[list[str]],
-        mm_processor_kwargs: Optional[dict[str, Any]],
-        precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
-    ) -> list[Optional[MultiModalKwargs]]:
-        if precomputed_mm_inputs is None:
-            image_inputs = mm_data["image"]
-            if not isinstance(image_inputs, list):
-                image_inputs = [image_inputs]
-            num_inputs = len(image_inputs)
-        else:
-            num_inputs = len(precomputed_mm_inputs)
-
-        # Sanity
-        if self.use_cache:
-            assert mm_hashes is not None
-            assert num_inputs == len(mm_hashes)
-
-        # Process each image input separately, so that later we can schedule
-        # them in a fine-grained manner.
-        # Apply caching (if enabled) and reuse precomputed inputs (if provided)
-        ret_inputs: list[Optional[MultiModalKwargs]] = []
-        for input_id in range(num_inputs):
-            if self.mm_debug_cache_hit_ratio_steps is not None:
-                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
-
-            mm_input = None
-            if self.use_cache:
-                assert mm_hashes is not None
-                mm_hash = mm_hashes[input_id]
-                mm_input = self.mm_cache.get(mm_hash)
-
-            self.mm_debug_cache_total += 1
-            if mm_input is None:
-                if precomputed_mm_inputs is not None:
-                    # Reuse precomputed input (for merged preprocessor)
-                    mm_input = precomputed_mm_inputs[input_id]
-                else:
-                    # Apply legacy input_mapper
-                    mm_input = self.multi_modal_input_mapper(
-                        {"image": [image_inputs[input_id]]},
-                        mm_processor_kwargs=mm_processor_kwargs,
-                    )
-
-                if self.use_cache:
-                    # Add to cache
-                    assert mm_hash is not None
-                    self.mm_cache[mm_hash] = mm_input
-            else:
-                self.mm_debug_cache_hits += 1
-                mm_input = None  # Avoids sending mm_input to Server
-
-            ret_inputs.append(mm_input)
-
-        return ret_inputs
-
-
 class MMInputCacheServer:

     def __init__(self, model_config):
@@ -135,9 +34,9 @@ def __init__(self, model_config):

     def get_and_update(
         self,
-        mm_inputs: list[Optional[MultiModalKwargs]],
+        mm_inputs: list[MultiModalKwargs],
         mm_hashes: list[str],
-    ) -> list[Optional[MultiModalKwargs]]:
+    ) -> list[MultiModalKwargs]:
         assert len(mm_inputs) == len(mm_hashes)

         if not self.use_cache:
@@ -147,8 +46,7 @@ def get_and_update(
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            assert mm_hash is not None
            if mm_input is None:
-                mm_input = self.mm_cache.get(mm_hash)
-                assert mm_input is not None
+                mm_input = self.mm_cache[mm_hash]
            else:
                self.mm_cache[mm_hash] = mm_input
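To make the mirrored-cache comment above concrete: the frontend process (P0) keeps its own hash-keyed cache and, on a hit, ships None instead of the processed tensors; MMInputCacheServer.get_and_update on the core process (P1) then restores the full list from its mirror. A minimal sketch of the protocol, assuming plain dicts in place of ProcessingCache's LRU cache and an illustrative client_send helper that is not part of vLLM:

from typing import Optional

# Sketch of the mirrored P0/P1 cache protocol described in the comments above.
# Plain dicts stand in for the LRU ProcessingCache; plain dicts of kwargs
# stand in for MultiModalKwargs.
client_cache: dict[str, dict] = {}   # lives in the frontend process (P0)
server_cache: dict[str, dict] = {}   # lives in the core process (P1)


def client_send(mm_hash: str, kwargs: dict) -> Optional[dict]:
    """Return what P0 would serialize over IPC: None on a cache hit."""
    if mm_hash in client_cache:
        return None                  # P1 is expected to hold it as well
    client_cache[mm_hash] = kwargs
    return kwargs


def server_get_and_update(mm_inputs: list[Optional[dict]],
                          mm_hashes: list[str]) -> list[dict]:
    """Analogue of MMInputCacheServer.get_and_update: fill Nones from cache."""
    full = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_input is None:
            mm_input = server_cache[mm_hash]   # must already be cached
        else:
            server_cache[mm_hash] = mm_input
        full.append(mm_input)
    return full


# First occurrence ships the tensors; the repeat ships only the hash.
payload = [client_send("img-abc", {"pixel_values": "..."}),
           client_send("img-abc", {"pixel_values": "..."})]
assert payload[1] is None
print(server_get_and_update(payload, ["img-abc", "img-abc"]))

Because both sides use the same cache size (VLLM_MM_INPUT_CACHE_GIB) and see the same stream of items, a client-side hit implies the server still holds the entry, which is what makes sending None in place of the kwargs safe.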

vllm/v1/engine/processor.py

Lines changed: 23 additions & 57 deletions
@@ -11,15 +11,15 @@
 from vllm.inputs.parse import is_encoder_decoder_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
-                             MultiModalKwargs, MultiModalRegistry)
+from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
+                             MultiModalRegistry)
+from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.mm_input_cache import MMInputCacheClient
 from vllm.v1.structured_output.utils import validate_structured_output_request


@@ -45,11 +45,6 @@ def __init__(
         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer,
                                                     mm_registry)
-        self.input_processor = input_registry.create_input_processor(
-            self.model_config)
-
-        # Multi-modal (huggingface) input mapper
-        self.mm_input_cache_client = MMInputCacheClient(self.model_config)

         # Multi-modal hasher (for images)
         self.use_hash = (
@@ -171,7 +166,7 @@ def process_inputs(
         # 2. For multimodal models with a merged preprocessor, preprocess
         #    multimodal data and expand prompt token ids accordingly.
         # 3. Apply prompt adapter to prompt token ids if one exists.
-        preprocessed_inputs = self.input_preprocessor.preprocess(
+        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
             lora_request=lora_request,
@@ -180,10 +175,6 @@ def process_inputs(
         )
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

-        # Process prompt and prompt token ids.
-        # Only applicable to multimodal models with legacy input processor.
-        processed_inputs = self.input_processor(preprocessed_inputs)
-
         self._validate_model_inputs(processed_inputs, lora_request)

         if is_encoder_decoder_inputs(processed_inputs):
@@ -212,36 +203,22 @@ def process_inputs(
             self.tokenizer.get_lora_tokenizer(lora_request))

         # Multimodal related.
-        # Compute MM hashes (if enabled)
-        mm_hashes = None
-        if self.use_hash:
-            # Use mm_hashes from processed inputs if the model has merged
-            # input processor.
-            if decoder_inputs.multi_modal_hashes:
-                mm_hashes = decoder_inputs.multi_modal_hashes
-            # Fallback to using MultiModalHasher directly.
-            else:
-                mm_hashes = MultiModalHasher.hash_prompt_mm_data(prompt)
+        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
+        sorted_mm_positions: Optional[list[PlaceholderRange]] = None
+        sorted_mm_hashes: Optional[list[str]] = None
+        if (decoder_mm_inputs := decoder_inputs.multi_modal_data):
+            assert isinstance(decoder_mm_inputs, MultiModalKwargs)

-        # For merged preprocessor, mm_data is already mm_inputs
-        precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None
-        decoder_mm_data = decoder_inputs.multi_modal_data
-        if isinstance(decoder_mm_data, MultiModalKwargs):
-            # The output of merged multi-modal processor (`decoder_mm_data`)
+            # The output of merged multi-modal processor (`decoder_mm_inputs`)
             # contains the kwargs for all items from all modalities.
             # This code separates them so that there is one set of kwargs
             # per item per modality.
-            precomputed_mm_inputs = [
+            individual_mm_inputs = [
                 MultiModalKwargs.from_items([item])
-                for modality in decoder_mm_data.modalities
-                for item in decoder_mm_data.get_items(modality)
+                for modality in decoder_mm_inputs.modalities
+                for item in decoder_mm_inputs.get_items(modality)
             ]

-            mm_positions = decoder_inputs.multi_modal_placeholders
-
-        # Last-mile processing of multimodal metadata and inputs.
-        if mm_positions:
-
             # Merge and flatten multimodal placeholders, hashes and inputs
             # from dictionaries to lists, and sort them by each item's position
             # in the input sequence.
@@ -251,41 +228,30 @@ def process_inputs(
                 sorted_mm_positions,
                 sorted_mm_hashes,
             ) = merge_and_sort_multimodal_metadata(
-                mm_positions,
-                mm_hashes,
+                decoder_inputs.multi_modal_placeholders,
+                decoder_inputs.multi_modal_hashes if self.use_hash else None,
             )

             # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
-            # modalities involved AND the model supports merged input processor.
-            if len(sorted_modalities) > 1 and precomputed_mm_inputs:
-
+            # modalities involved.
+            if len(sorted_modalities) > 1:
                 modality_order_dict = {
                     modality: order
                     for order, modality in enumerate(sorted_modalities)
                 }

                 # Sanity check to make sure each multimodal input has only one
                 # modality key.
-                for mm_input in precomputed_mm_inputs:
+                for mm_input in individual_mm_inputs:
                     assert len(mm_input.modalities) == 1

-                # Sort MultiModalKwags to match sorted_mm_positions
-                precomputed_mm_inputs = sorted(
-                    precomputed_mm_inputs,
+                # Sort MultiModalKwargs to match sorted_mm_positions
+                sorted_mm_inputs = sorted(
+                    individual_mm_inputs,
                     key=lambda mm_input: modality_order_dict[list(
                         mm_input.modalities)[0]])
-
-            # Apply mm input cache update and legacy input mapper if one exists.
-            sorted_mm_inputs = self.mm_input_cache_client.process_inputs(
-                mm_data=decoder_mm_data,
-                mm_hashes=sorted_mm_hashes,
-                mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
-                precomputed_mm_inputs=precomputed_mm_inputs,
-            )
-        else:
-            sorted_mm_inputs = None
-            sorted_mm_hashes = None
-            sorted_mm_positions = None
+            else:
+                sorted_mm_inputs = individual_mm_inputs

         return EngineCoreRequest(
             request_id=request_id,
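The rewritten block above splits the merged processor output into one MultiModalKwargs per item and, only when more than one modality is present, reorders those per-item kwargs so they line up with the placeholder order returned by merge_and_sort_multimodal_metadata. A standalone sketch of that sorting step, with plain dicts standing in for MultiModalKwargs and a hard-coded modality order in place of the real helper:

# Sketch of the modality-order sort from process_inputs above. Each entry in
# `individual_inputs` stands in for a single-item MultiModalKwargs; the
# modality order would normally come from merge_and_sort_multimodal_metadata
# (here it is hard-coded for illustration).
individual_inputs = [
    {"modality": "image", "data": "img0"},
    {"modality": "image", "data": "img1"},
    {"modality": "audio", "data": "aud0"},
]

# Suppose the sorted placeholder metadata says audio items appear first in the
# prompt, followed by images.
sorted_modalities = ["audio", "image"]
modality_order = {modality: i for i, modality in enumerate(sorted_modalities)}

# Python's sorted() is stable, so items within one modality keep their
# original relative order while modalities are rearranged as a whole.
sorted_inputs = sorted(individual_inputs,
                       key=lambda inp: modality_order[inp["modality"]])

print([inp["data"] for inp in sorted_inputs])  # ['aud0', 'img0', 'img1']

With a single modality the sort is skipped entirely and sorted_mm_inputs is just individual_mm_inputs, as in the else branch of the diff.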
