
Commit b86e411

DarkLight1337 authored and zRzRzRzRzRzRzR committed
[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (vllm-project#16273)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: zRzRzRzRzRzRzR <[email protected]>
1 parent f9988ba commit b86e411

5 files changed: +65 −22 lines changed


vllm/v1/engine/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union
 
 import msgspec
@@ -52,7 +53,7 @@ class EngineCoreRequest(
     # Detokenizer, but set to None when it is added to EngineCoreClient.
     prompt: Optional[str]
     prompt_token_ids: list[int]
-    mm_inputs: Optional[list[MultiModalKwargs]]
+    mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: SamplingParams
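
The widened annotation is the heart of the fix: after the client-side cache pass, a cache hit travels from P0 to P1 as a None placeholder instead of the full MultiModalKwargs payload. A minimal sketch of the shape change (plain dicts stand in for MultiModalKwargs; not vLLM code):

    from collections.abc import Sequence
    from typing import Optional

    MMInput = dict  # stand-in for MultiModalKwargs

    # Old annotation: every entry had to be a concrete payload.
    mm_inputs_old: list[MMInput] = [{"pixel_values": [1, 2, 3]}]

    # New annotation: a cache hit ships as None; only its hash travels,
    # and P1 restores the payload from its mirrored cache.
    mm_inputs_new: Sequence[Optional[MMInput]] = [
        {"pixel_values": [1, 2, 3]},
        None,
    ]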

vllm/v1/engine/core.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
-from vllm.v1.engine.mm_input_cache import MMInputCacheServer
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -105,7 +105,7 @@ def __init__(
         )
 
         # Setup MM Input Mapper.
-        self.mm_input_cache_server = MMInputCacheServer(
+        self.mm_input_cache_server = MirroredProcessingCache(
             vllm_config.model_config)
 
         # Setup batch queue for pipeline parallelism.
@@ -173,7 +173,7 @@ def add_request(self, request: EngineCoreRequest):
             # anything that has a hash must have a HIT cache entry here
             # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)
 
         req = Request.from_engine_core_request(request)
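
The in-diff comment ("anything that has a hash must have a HIT cache entry here as well") leans on a mirroring invariant: both processes see the same request stream and run same-capacity LRU caches, so they fill and evict in lockstep. A toy demonstration under those assumptions; ToyLRU and the hash stream are made up for illustration and are not the vLLM ProcessingCache:

    from collections import OrderedDict

    class ToyLRU:
        """Tiny LRU used only for this demo."""

        def __init__(self, capacity: int):
            self.capacity = capacity
            self.data: OrderedDict[str, str] = OrderedDict()

        def __contains__(self, key: str) -> bool:
            return key in self.data

        def __setitem__(self, key: str, value: str) -> None:
            self.data[key] = value
            self.data.move_to_end(key)
            if len(self.data) > self.capacity:
                self.data.popitem(last=False)  # evict the oldest entry

    p0, p1 = ToyLRU(2), ToyLRU(2)
    for h in ["a", "b", "a", "c"]:  # both mirrors see the same hash stream
        for cache in (p0, p1):
            if h not in cache:
                cache[h] = f"kwargs-{h}"

    # The mirrors agree on every hash, so a P0 hit implies a P1 hit.
    assert p0.data.keys() == p1.data.keys()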

vllm/v1/engine/mm_input_cache.py

Lines changed: 34 additions & 7 deletions
@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+from collections.abc import Sequence
+from typing import Optional
 
 from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.processing import ProcessingCache
+from vllm.utils import is_list_of
 
 # The idea of multimodal preprocessing caching is based on having a client and
 # a server, where the client executes in the frontend process (=P0) and the
@@ -11,9 +14,11 @@
 # -- Client:
 #  - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
 #    with built-in caching functionality, with mm_hash as its identifier.
+#  - MirroredProcessingCache to keep track of the cached entries and
+#    determine whether to send the MultiModalKwargs to P1.
 #
 # -- Server:
-#  - MMInputCacheServer to perform caching of the received MultiModalKwargs.
+#  - MirroredProcessingCache to store the MultiModalKwargs from P0.
 #
 # The caching for both client and server is mirrored, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -25,26 +30,48 @@
 # variable VLLM_MM_INPUT_CACHE_GIB.
 
 
-class MMInputCacheServer:
+class MirroredProcessingCache:
 
     def __init__(self, model_config):
        self.use_cache = not model_config.disable_mm_preprocessor_cache
        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                      MultiModalKwargs)
 
-    def get_and_update(
+    def get_and_update_p0(
         self,
-        mm_inputs: list[MultiModalKwargs],
+        mm_inputs: Sequence[MultiModalKwargs],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargs]:
+    ) -> Sequence[Optional[MultiModalKwargs]]:
         assert len(mm_inputs) == len(mm_hashes)
 
         if not self.use_cache:
+            assert is_list_of(mm_inputs, MultiModalKwargs)
             return mm_inputs
 
-        full_mm_inputs = []
+        full_mm_inputs = list[Optional[MultiModalKwargs]]()
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            if mm_hash in self.mm_cache:
+                mm_input = None
+            else:
+                self.mm_cache[mm_hash] = mm_input
+
+            full_mm_inputs.append(mm_input)
+
+        return full_mm_inputs
+
+    def get_and_update_p1(
+        self,
+        mm_inputs: Sequence[Optional[MultiModalKwargs]],
+        mm_hashes: list[str],
+    ) -> Sequence[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        if not self.use_cache:
+            assert is_list_of(mm_inputs, MultiModalKwargs)
+            return mm_inputs
+
+        full_mm_inputs = list[MultiModalKwargs]()
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
-            assert mm_hash is not None
             if mm_input is None:
                 mm_input = self.mm_cache[mm_hash]
             else:
vllm/v1/engine/processor.py

Lines changed: 17 additions & 6 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Literal, Optional, Union
 
 from vllm.config import VllmConfig
@@ -19,6 +19,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
 from vllm.v1.structured_output.utils import (
@@ -47,6 +48,8 @@ def __init__(
             self.tokenizer,
             mm_registry)
 
+        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
+
         # Multi-modal hasher (for images)
         self.use_hash = (
             not self.model_config.disable_mm_preprocessor_cache) or \
@@ -231,7 +234,7 @@ def process_inputs(
             self.tokenizer.get_lora_tokenizer(lora_request))
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
+        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -256,20 +259,28 @@ def process_inputs(
             # are multiple modalities.
             unique_modalities = set(sorted_item_modalities)
             if len(unique_modalities) > 1:
-                sorted_mm_inputs = []
+                orig_sorted_mm_inputs = []
                 used_indices = {modality: 0 for modality in unique_modalities}
+
                 for modality in sorted_item_modalities:
                     items = decoder_mm_inputs.get_items(modality)
                     item = items[used_indices[modality]]
-                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
-                                                                         ]))
+
+                    orig_sorted_mm_inputs.append(
+                        MultiModalKwargs.from_items([item]))
                     used_indices[modality] += 1
             else:
-                sorted_mm_inputs = [
+                orig_sorted_mm_inputs = [
                     MultiModalKwargs.from_items([item]) for item in
                     decoder_mm_inputs.get_items(sorted_item_modalities[0])
                 ]
 
+            if sorted_mm_hashes is not None:
+                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
+                    orig_sorted_mm_inputs, sorted_mm_hashes)
+            else:
+                sorted_mm_inputs = orig_sorted_mm_inputs
+
         return EngineCoreRequest(
             request_id=request_id,
             prompt=decoder_inputs.get("prompt"),
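
The multi-modality branch above deserves a gloss: items arrive grouped per modality, but the engine needs them in prompt order, so a per-modality cursor (`used_indices`) replays that order. A toy rendering of the same loop (the data and names are made up for illustration):

    items_by_modality = {"image": ["img0", "img1"], "audio": ["aud0"]}
    sorted_item_modalities = ["image", "audio", "image"]  # prompt order

    used_indices = {m: 0 for m in items_by_modality}
    ordered = []
    for modality in sorted_item_modalities:
        # Take the next unused item of this modality.
        ordered.append(items_by_modality[modality][used_indices[modality]])
        used_indices[modality] += 1

    assert ordered == ["img0", "aud0", "img1"]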

vllm/v1/request.py

Lines changed: 9 additions & 5 deletions
@@ -3,17 +3,16 @@
 import enum
 from typing import TYPE_CHECKING, Optional, Union
 
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
+from vllm.utils import is_list_of
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreRequest, FinishReason)
 from vllm.v1.structured_output.request import StructuredOutputRequest
 from vllm.v1.utils import ConstantList
 
 if TYPE_CHECKING:
-
     from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.inputs import PlaceholderRange
 
 
 class Request:
@@ -23,9 +22,9 @@ def __init__(
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: list[int],
-        multi_modal_inputs: Optional[list["MultiModalKwargs"]],
+        multi_modal_inputs: Optional[list[MultiModalKwargs]],
         multi_modal_hashes: Optional[list[str]],
-        multi_modal_placeholders: Optional[list["PlaceholderRange"]],
+        multi_modal_placeholders: Optional[list[PlaceholderRange]],
         sampling_params: SamplingParams,
         eos_token_id: Optional[int],
         arrival_time: float,
@@ -75,6 +74,11 @@
 
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+        if request.mm_inputs is not None:
+            assert isinstance(request.mm_inputs, list)
+            assert is_list_of(request.mm_inputs, MultiModalKwargs), (
+                "mm_inputs was not updated in EngineCore.add_request")
+
         return cls(
             request_id=request.request_id,
             prompt=request.prompt,
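
The new assertions make the contract explicit: by the time a Request is built, P1 must already have swapped every None placeholder back for real kwargs. A minimal sketch of that guard, with a plain all() check standing in for vLLM's is_list_of (not vLLM code):

    def check_restored(mm_inputs) -> None:
        # Mirrors the intent of the assertions in from_engine_core_request.
        if mm_inputs is not None:
            assert isinstance(mm_inputs, list)
            assert all(item is not None for item in mm_inputs), (
                "mm_inputs was not updated in EngineCore.add_request")

    check_restored([{"pixel_values": [1]}])  # fully materialized: passes
    try:
        check_restored([None])  # a leaked placeholder trips the guard
    except AssertionError as exc:
        print(f"caught: {exc}")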
