 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Optional
-
-from vllm.config import ModelConfig
 from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
-from vllm.logger import init_logger
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
-                             MultiModalKwargs, MultiModalRegistry)
+from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.processing import ProcessingCache
 
-logger = init_logger(__name__)
-
 # The idea of multimodal preprocessing caching is based on having a client and
 # a server, where the client executes in the frontend process (=P0) and the
 # server in the core process (=P1).
 #
 # -- Client:
-# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
-# - Perform caching of the generated MultiModalKwargs.
-# - This client can be deprecated once all mutimodal models migrate to use
-#   merged preprocessor with built-in caching functionality.
+# - BaseMultiModalProcessor processes MultiModalData into MultiModalKwargs
+#   with built-in caching functionality, using mm_hash as the identifier.
 #
 # -- Server:
-# - Perform caching of the received MultiModalKwargs.
+# - MMInputCacheServer performs caching of the received MultiModalKwargs.
 #
-# The caching for both client and server is mirrored/similar, and this allows us
+# The caching for both client and server is mirrored, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
-# client (=P0) and server (=P1) processes.
+# client (=P0) and server (=P1) processes if the mm_hash is found in the
+# client cache.
 
 # Both Client and Server must use the same cache size
 # (to perform mirrored caching). This cache size is set by the environment
 # variable VLLM_MM_INPUT_CACHE_GIB.
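
To make the mirrored-caching idea concrete, here is a minimal sketch. It is not vLLM's actual API: `MirroredLRUCache`, `client_send`, and `server_recv` are hypothetical names, and capacity is counted in entries for simplicity rather than sized in GiB as VLLM_MM_INPUT_CACHE_GIB is. It shows how two same-capacity LRU caches let P0 replace an already-seen payload with `None` and let P1 restore it:

```python
from collections import OrderedDict
from typing import Any, Optional


class MirroredLRUCache:
    """Hypothetical fixed-capacity LRU cache; P0 and P1 each hold one.

    Both sides must use the same capacity so their eviction decisions
    stay in lockstep (in vLLM, the shared size comes from the
    VLLM_MM_INPUT_CACHE_GIB environment variable).
    """

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data: "OrderedDict[str, Any]" = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict least recently used


def client_send(cache: MirroredLRUCache, mm_hash: str,
                mm_kwargs: Any) -> Optional[Any]:
    """P0 side: return None instead of the payload on a cache hit."""
    if cache.get(mm_hash) is not None:
        return None  # the server's mirrored cache already holds the payload
    cache.put(mm_hash, mm_kwargs)
    return mm_kwargs  # first sighting: payload must cross the process boundary


def server_recv(cache: MirroredLRUCache, mm_hash: str,
                payload: Optional[Any]) -> Any:
    """P1 side: restore elided payloads from the mirrored cache."""
    if payload is None:
        return cache.get(mm_hash)  # a hit on P0 implies presence here too
    cache.put(mm_hash, payload)
    return payload
```

Because both caches see the same sequence of hashes and apply the same eviction policy with the same capacity, a hit on the client implies the entry is still resident on the server, which is what makes eliding the serialized payload safe.
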
 
 
-# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
-# merged preprocessor with built-in caching functionality.
-class MMInputCacheClient:
-
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-    ):
-        self.model_config = model_config
-        self.mm_registry = mm_registry
-        self.multi_modal_input_mapper = mm_registry.create_input_mapper(
-            model_config)
-        self.mm_registry.init_mm_limits_per_prompt(model_config)
-
-        # Init cache
-        self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
-                                                      MultiModalKwargs)
-
-        # DEBUG: Set to None to disable
-        self.mm_debug_cache_hit_ratio_steps = None
-        self.mm_debug_cache_hits = 0
-        self.mm_debug_cache_total = 0
-
-    def cache_hit_ratio(self, steps):
-        total = self.mm_debug_cache_total
-
-        if total > 0 and total % steps == 0:
-            logger.debug("MMInputMapper: cache_hit_ratio = %.2f",
-                         self.mm_debug_cache_hits / total)
-
-    # NOTE: process_inputs only supports image inputs since all multimodal
-    # models with other modalities have migrated to use merged preprocessor.
-    def process_inputs(
-        self,
-        mm_data: MultiModalDataDict,
-        mm_hashes: Optional[list[str]],
-        mm_processor_kwargs: Optional[dict[str, Any]],
-        precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
-    ) -> list[Optional[MultiModalKwargs]]:
-        if precomputed_mm_inputs is None:
-            image_inputs = mm_data["image"]
-            if not isinstance(image_inputs, list):
-                image_inputs = [image_inputs]
-            num_inputs = len(image_inputs)
-        else:
-            num_inputs = len(precomputed_mm_inputs)
-
-        # Sanity
-        if self.use_cache:
-            assert mm_hashes is not None
-            assert num_inputs == len(mm_hashes)
-
-        # Process each image input separately, so that later we can schedule
-        # them in a fine-grained manner.
-        # Apply caching (if enabled) and reuse precomputed inputs (if provided)
-        ret_inputs: list[Optional[MultiModalKwargs]] = []
-        for input_id in range(num_inputs):
-            if self.mm_debug_cache_hit_ratio_steps is not None:
-                self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
-
-            mm_input = None
-            if self.use_cache:
-                assert mm_hashes is not None
-                mm_hash = mm_hashes[input_id]
-                mm_input = self.mm_cache.get(mm_hash)
-
-            self.mm_debug_cache_total += 1
-            if mm_input is None:
-                if precomputed_mm_inputs is not None:
-                    # Reuse precomputed input (for merged preprocessor)
-                    mm_input = precomputed_mm_inputs[input_id]
-                else:
-                    # Apply legacy input_mapper
-                    mm_input = self.multi_modal_input_mapper(
-                        {"image": [image_inputs[input_id]]},
-                        mm_processor_kwargs=mm_processor_kwargs,
-                    )
-
-                if self.use_cache:
-                    # Add to cache
-                    assert mm_hash is not None
-                    self.mm_cache[mm_hash] = mm_input
-            else:
-                self.mm_debug_cache_hits += 1
-                mm_input = None  # Avoids sending mm_input to Server
-
-            ret_inputs.append(mm_input)
-
-        return ret_inputs
-
-
 class MMInputCacheServer:
 
     def __init__(self, model_config):
@@ -135,9 +34,9 @@ def __init__(self, model_config):
 
     def get_and_update(
         self,
-        mm_inputs: list[Optional[MultiModalKwargs]],
+        mm_inputs: list[MultiModalKwargs],
         mm_hashes: list[str],
-    ) -> list[Optional[MultiModalKwargs]]:
+    ) -> list[MultiModalKwargs]:
         assert len(mm_inputs) == len(mm_hashes)
 
         if not self.use_cache:
@@ -147,8 +46,7 @@ def get_and_update(
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            assert mm_hash is not None
            if mm_input is None:
-                mm_input = self.mm_cache.get(mm_hash)
-                assert mm_input is not None
+                mm_input = self.mm_cache[mm_hash]
            else:
                self.mm_cache[mm_hash] = mm_input
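
For reference, here is a hedged toy round trip of the contract that `get_and_update` implements on the server side, reusing the hypothetical `MirroredLRUCache`, `client_send`, and `server_recv` names from the sketch above (the hash and payload values are stand-ins, not real vLLM data): the first request carries the payload and populates both caches, while a repeat request carries `None` and is reconstituted from the server's mirror.

```python
# Toy round trip using the hypothetical helpers sketched earlier.
p0_cache = MirroredLRUCache(capacity=4)
p1_cache = MirroredLRUCache(capacity=4)

pixel_values = {"pixel_values": [0.1, 0.2, 0.3]}  # stand-in for real tensors
mm_hash = "sha256-of-image-bytes"                 # stand-in for a real hash

# First request: the payload crosses P0 -> P1 and both caches are populated.
wire = client_send(p0_cache, mm_hash, pixel_values)
assert wire is not None
assert server_recv(p1_cache, mm_hash, wire) == pixel_values

# Repeat request: P0 elides the payload; P1 restores it from its mirror,
# mirroring the `if mm_input is None` branch of get_and_update.
wire = client_send(p0_cache, mm_hash, pixel_values)
assert wire is None
assert server_recv(p1_cache, mm_hash, wire) == pixel_values
```
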