From b3dbf5c4bbcae80d2d0856204aa2e07b5f181151 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 8 Apr 2025 15:37:51 +0000 Subject: [PATCH 1/3] [Bugfix] Avoid transfering cached multi-modal items from P0 to P1 Signed-off-by: DarkLight1337 --- vllm/v1/engine/__init__.py | 3 ++- vllm/v1/engine/core.py | 6 ++--- vllm/v1/engine/mm_input_cache.py | 41 ++++++++++++++++++++++++++------ vllm/v1/engine/processor.py | 31 ++++++++++++++++-------- vllm/v1/request.py | 14 +++++++---- 5 files changed, 69 insertions(+), 26 deletions(-) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 0557d0c6c19d..1264e43c79d9 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -2,6 +2,7 @@ import enum import time +from collections.abc import Sequence from typing import Any, Optional, Union import msgspec @@ -52,7 +53,7 @@ class EngineCoreRequest( # Detokenizer, but set to None when it is added to EngineCoreClient. prompt: Optional[str] prompt_token_ids: list[int] - mm_inputs: Optional[list[MultiModalKwargs]] + mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_hashes: Optional[list[str]] mm_placeholders: Optional[list[PlaceholderRange]] sampling_params: SamplingParams diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f58c77e4f165..077d49988962 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -31,7 +31,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) -from vllm.v1.engine.mm_input_cache import MMInputCacheServer +from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput @@ -105,7 +105,7 @@ def __init__( ) # Setup MM Input Mapper. - self.mm_input_cache_server = MMInputCacheServer( + self.mm_input_cache_server = MirroredProcessingCache( vllm_config.model_config) # Setup batch queue for pipeline parallelism. @@ -173,7 +173,7 @@ def add_request(self, request: EngineCoreRequest): # anything that has a hash must have a HIT cache entry here # as well. assert request.mm_inputs is not None - request.mm_inputs = self.mm_input_cache_server.get_and_update( + request.mm_inputs = self.mm_input_cache_server.get_and_update_p1( request.mm_inputs, request.mm_hashes) req = Request.from_engine_core_request(request) diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index 61a55d2499bd..ef5a2e5acb15 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence +from typing import Optional from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.multimodal import MultiModalKwargs from vllm.multimodal.processing import ProcessingCache +from vllm.utils import is_list_of # The idea of multimodal preprocessing caching is based on having a client and # a server, where the client executes in the frontend process (=P0) and the @@ -11,9 +14,11 @@ # -- Client: # - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs # with built-in caching functionality, with mm_hash as its identifier. +# - MirroredProcessingCache to keep track of the cached entries and +# determine whether to send the MultiModalKwargs to P1. # # -- Server: -# - MMInputCacheServer to perform caching of the received MultiModalKwargs. 
+# - MirroredProcessingCache to store the MultiModalKwargs from P0. # # The caching for both client and server is mirrored, and this allows us # to avoid the serialization of "mm_inputs" (like pixel values) between @@ -25,26 +30,48 @@ # variable VLLM_MM_INPUT_CACHE_GIB. -class MMInputCacheServer: +class MirroredProcessingCache: def __init__(self, model_config): self.use_cache = not model_config.disable_mm_preprocessor_cache self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, MultiModalKwargs) - def get_and_update( + def get_and_update_p0( self, - mm_inputs: list[MultiModalKwargs], + mm_inputs: Sequence[MultiModalKwargs], mm_hashes: list[str], - ) -> list[MultiModalKwargs]: + ) -> Sequence[Optional[MultiModalKwargs]]: assert len(mm_inputs) == len(mm_hashes) if not self.use_cache: + assert is_list_of(mm_inputs, MultiModalKwargs) return mm_inputs - full_mm_inputs = [] + full_mm_inputs = list[Optional[MultiModalKwargs]]() + for mm_input, mm_hash in zip(mm_inputs, mm_hashes): + if mm_hash in self.mm_cache: + mm_input = None + else: + self.mm_cache[mm_hash] = mm_input + + full_mm_inputs.append(mm_input) + + return full_mm_inputs + + def get_and_update_p1( + self, + mm_inputs: Sequence[Optional[MultiModalKwargs]], + mm_hashes: list[str], + ) -> Sequence[MultiModalKwargs]: + assert len(mm_inputs) == len(mm_hashes) + + if not self.use_cache: + assert is_list_of(mm_inputs, MultiModalKwargs) + return mm_inputs + + full_mm_inputs = list[MultiModalKwargs]() for mm_input, mm_hash in zip(mm_inputs, mm_hashes): - assert mm_hash is not None if mm_input is None: mm_input = self.mm_cache[mm_hash] else: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 403edddfcbee..dcddaefb3401 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import time -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Optional, Union from vllm.config import VllmConfig @@ -18,6 +18,7 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreRequest +from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.utils import ( @@ -46,6 +47,8 @@ def __init__( self.tokenizer, mm_registry) + self.mm_input_cache_client = MirroredProcessingCache(self.model_config) + # Multi-modal hasher (for images) self.use_hash = ( not self.model_config.disable_mm_preprocessor_cache) or \ @@ -230,7 +233,7 @@ def process_inputs( self.tokenizer.get_lora_tokenizer(lora_request)) # Multimodal related. - sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None + sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None sorted_mm_positions: Optional[list[PlaceholderRange]] = None sorted_mm_hashes: Optional[list[str]] = None if decoder_inputs["type"] == "multimodal": @@ -240,7 +243,7 @@ def process_inputs( # from dictionaries to lists, and sort them by each item's position # in the input sequence. ( - sorted_item_modalities, + sorted_modalities, sorted_mm_positions, sorted_mm_hashes, ) = merge_and_sort_multimodal_metadata( @@ -253,22 +256,30 @@ def process_inputs( # This code flattens kwargs for individual items in a list and # sorts them by each item's position in the input sequence if there # are multiple modalities. 
- unique_modalities = set(sorted_item_modalities) + unique_modalities = set(sorted_modalities) if len(unique_modalities) > 1: - sorted_mm_inputs = [] + sorted_mm_inputs_ = [] used_indices = {modality: 0 for modality in unique_modalities} - for modality in sorted_item_modalities: + + for modality in sorted_modalities: items = decoder_mm_inputs.get_items(modality) item = items[used_indices[modality]] - sorted_mm_inputs.append(MultiModalKwargs.from_items([item - ])) + + sorted_mm_inputs_.append( + MultiModalKwargs.from_items([item])) used_indices[modality] += 1 else: - sorted_mm_inputs = [ + sorted_mm_inputs_ = [ MultiModalKwargs.from_items([item]) for item in - decoder_mm_inputs.get_items(sorted_item_modalities[0]) + decoder_mm_inputs.get_items(sorted_modalities[0]) ] + if sorted_mm_hashes is not None: + sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0( + sorted_mm_inputs_, sorted_mm_hashes) + else: + sorted_mm_inputs = sorted_mm_inputs_ + return EngineCoreRequest( request_id=request_id, prompt=decoder_inputs.get("prompt"), diff --git a/vllm/v1/request.py b/vllm/v1/request.py index daf59fd76e9a..6be72431dde5 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,17 +3,16 @@ import enum from typing import TYPE_CHECKING, Optional, Union +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams +from vllm.utils import is_list_of from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.utils import ConstantList if TYPE_CHECKING: - from vllm.lora.request import LoRARequest - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.inputs import PlaceholderRange class Request: @@ -23,9 +22,9 @@ def __init__( request_id: str, prompt: Optional[str], prompt_token_ids: list[int], - multi_modal_inputs: Optional[list["MultiModalKwargs"]], + multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], - multi_modal_placeholders: Optional[list["PlaceholderRange"]], + multi_modal_placeholders: Optional[list[PlaceholderRange]], sampling_params: SamplingParams, eos_token_id: Optional[int], arrival_time: float, @@ -75,6 +74,11 @@ def __init__( @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + if request.mm_inputs is not None: + assert isinstance(request.mm_inputs, list) + assert is_list_of(request.mm_inputs, MultiModalKwargs), ( + "mm_inputs was not updated in EngineCore.add_request") + return cls( request_id=request.request_id, prompt=request.prompt, From f920e47c4e72d358dade789e2775e82c5f9f6967 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 9 Apr 2025 02:34:00 +0000 Subject: [PATCH 2/3] Revert variable name change Signed-off-by: DarkLight1337 --- vllm/v1/engine/processor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 511aa5fe7173..432900f55296 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -244,7 +244,7 @@ def process_inputs( # from dictionaries to lists, and sort them by each item's position # in the input sequence. 
( - sorted_modalities, + sorted_item_modalities, sorted_mm_positions, sorted_mm_hashes, ) = merge_and_sort_multimodal_metadata( @@ -257,12 +257,12 @@ def process_inputs( # This code flattens kwargs for individual items in a list and # sorts them by each item's position in the input sequence if there # are multiple modalities. - unique_modalities = set(sorted_modalities) + unique_modalities = set(sorted_item_modalities) if len(unique_modalities) > 1: sorted_mm_inputs_ = [] used_indices = {modality: 0 for modality in unique_modalities} - for modality in sorted_modalities: + for modality in sorted_item_modalities: items = decoder_mm_inputs.get_items(modality) item = items[used_indices[modality]] @@ -272,7 +272,7 @@ def process_inputs( else: sorted_mm_inputs_ = [ MultiModalKwargs.from_items([item]) for item in - decoder_mm_inputs.get_items(sorted_modalities[0]) + decoder_mm_inputs.get_items(sorted_item_modalities[0]) ] if sorted_mm_hashes is not None: From fec139ade3c1d9ad5538b839b86b63379c7f0db8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 9 Apr 2025 02:34:36 +0000 Subject: [PATCH 3/3] Improve naming Signed-off-by: DarkLight1337 --- vllm/v1/engine/processor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 432900f55296..5f9c8ea4835f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -259,27 +259,27 @@ def process_inputs( # are multiple modalities. unique_modalities = set(sorted_item_modalities) if len(unique_modalities) > 1: - sorted_mm_inputs_ = [] + orig_sorted_mm_inputs = [] used_indices = {modality: 0 for modality in unique_modalities} for modality in sorted_item_modalities: items = decoder_mm_inputs.get_items(modality) item = items[used_indices[modality]] - sorted_mm_inputs_.append( + orig_sorted_mm_inputs.append( MultiModalKwargs.from_items([item])) used_indices[modality] += 1 else: - sorted_mm_inputs_ = [ + orig_sorted_mm_inputs = [ MultiModalKwargs.from_items([item]) for item in decoder_mm_inputs.get_items(sorted_item_modalities[0]) ] if sorted_mm_hashes is not None: sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0( - sorted_mm_inputs_, sorted_mm_hashes) + orig_sorted_mm_inputs, sorted_mm_hashes) else: - sorted_mm_inputs = sorted_mm_inputs_ + sorted_mm_inputs = orig_sorted_mm_inputs return EngineCoreRequest( request_id=request_id,
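
As a reading aid for the diff above, the following standalone sketch shows the mirrored P0/P1 round trip that get_and_update_p0 and get_and_update_p1 implement. It is an illustration only, not the vLLM code: the MirroredCacheSketch class, its plain-dict cache, and the dict payloads are stand-ins for MirroredProcessingCache, the LRU ProcessingCache, and MultiModalKwargs, and the use_cache toggle plus the VLLM_MM_INPUT_CACHE_GIB size limit are omitted.

from typing import Optional


class MirroredCacheSketch:
    """Simplified stand-in for MirroredProcessingCache (illustration only)."""

    def __init__(self) -> None:
        # A plain dict stands in for the real LRU ProcessingCache.
        self.mm_cache: dict[str, dict] = {}

    def get_and_update_p0(
        self,
        mm_inputs: list[dict],
        mm_hashes: list[str],
    ) -> list[Optional[dict]]:
        """Frontend (P0) side: replace already-cached items with None before IPC."""
        assert len(mm_inputs) == len(mm_hashes)
        out: list[Optional[dict]] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_hash in self.mm_cache:
                out.append(None)            # P1's mirror already has this item
            else:
                self.mm_cache[mm_hash] = mm_input
                out.append(mm_input)        # send the full payload only once
        return out

    def get_and_update_p1(
        self,
        mm_inputs: list[Optional[dict]],
        mm_hashes: list[str],
    ) -> list[dict]:
        """Engine core (P1) side: restore None entries from the mirrored cache."""
        assert len(mm_inputs) == len(mm_hashes)
        out: list[dict] = []
        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
            if mm_input is None:
                mm_input = self.mm_cache[mm_hash]   # was a cache hit on P0
            else:
                self.mm_cache[mm_hash] = mm_input   # first time this hash is seen
            out.append(mm_input)
        return out


if __name__ == "__main__":
    p0, p1 = MirroredCacheSketch(), MirroredCacheSketch()
    payload = {"pixel_values": [0.1, 0.2]}

    # First request: the item is new, so the full payload crosses the boundary.
    sent = p0.get_and_update_p0([payload], ["hash-a"])
    assert sent == [payload]
    assert p1.get_and_update_p1(sent, ["hash-a"]) == [payload]

    # Repeat request with the same hash: only None is serialized over IPC,
    # and P1 recovers the full payload from its own mirror.
    sent = p0.get_and_update_p0([payload], ["hash-a"])
    assert sent == [None]
    assert p1.get_and_update_p1(sent, ["hash-a"]) == [payload]

Because both sides apply the same update rule to the same requests in the same order, the two caches stay mirrored without extra coordination, so large multi-modal payloads such as pixel values only need to be serialized from P0 to P1 the first time their hash is seen.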