1
1
# SPDX-License-Identifier: Apache-2.0
2
2
3
3
import time
4
- from collections .abc import Mapping
4
+ from collections .abc import Mapping , Sequence
5
5
from typing import Literal , Optional , Union
6
6
7
7
from vllm .config import VllmConfig
19
19
from vllm .sampling_params import SamplingParams
20
20
from vllm .transformers_utils .tokenizer_group import BaseTokenizerGroup
21
21
from vllm .v1 .engine import EngineCoreRequest
22
+ from vllm .v1 .engine .mm_input_cache import MirroredProcessingCache
22
23
from vllm .v1 .structured_output .backend_guidance import (
23
24
validate_guidance_grammar )
24
25
from vllm .v1 .structured_output .utils import (
@@ -47,6 +48,8 @@ def __init__(
47
48
self .tokenizer ,
48
49
mm_registry )
49
50
51
+ self .mm_input_cache_client = MirroredProcessingCache (self .model_config )
52
+
50
53
# Multi-modal hasher (for images)
51
54
self .use_hash = (
52
55
not self .model_config .disable_mm_preprocessor_cache ) or \
@@ -231,7 +234,7 @@ def process_inputs(
231
234
self .tokenizer .get_lora_tokenizer (lora_request ))
232
235
233
236
# Multimodal related.
234
- sorted_mm_inputs : Optional [list [ MultiModalKwargs ]] = None
237
+ sorted_mm_inputs : Optional [Sequence [ Optional [ MultiModalKwargs ] ]] = None
235
238
sorted_mm_positions : Optional [list [PlaceholderRange ]] = None
236
239
sorted_mm_hashes : Optional [list [str ]] = None
237
240
if decoder_inputs ["type" ] == "multimodal" :
@@ -256,20 +259,28 @@ def process_inputs(
256
259
# are multiple modalities.
257
260
unique_modalities = set (sorted_item_modalities )
258
261
if len (unique_modalities ) > 1 :
259
- sorted_mm_inputs = []
262
+ orig_sorted_mm_inputs = []
260
263
used_indices = {modality : 0 for modality in unique_modalities }
264
+
261
265
for modality in sorted_item_modalities :
262
266
items = decoder_mm_inputs .get_items (modality )
263
267
item = items [used_indices [modality ]]
264
- sorted_mm_inputs .append (MultiModalKwargs .from_items ([item
265
- ]))
268
+
269
+ orig_sorted_mm_inputs .append (
270
+ MultiModalKwargs .from_items ([item ]))
266
271
used_indices [modality ] += 1
267
272
else :
268
- sorted_mm_inputs = [
273
+ orig_sorted_mm_inputs = [
269
274
MultiModalKwargs .from_items ([item ]) for item in
270
275
decoder_mm_inputs .get_items (sorted_item_modalities [0 ])
271
276
]
272
277
278
+ if sorted_mm_hashes is not None :
279
+ sorted_mm_inputs = self .mm_input_cache_client .get_and_update_p0 (
280
+ orig_sorted_mm_inputs , sorted_mm_hashes )
281
+ else :
282
+ sorted_mm_inputs = orig_sorted_mm_inputs
283
+
273
284
return EngineCoreRequest (
274
285
request_id = request_id ,
275
286
prompt = decoder_inputs .get ("prompt" ),
0 commit comments