diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index ac136698ee17..6d8f27530cee 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -22,8 +22,8 @@
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
+                                        BaseProcessingInfo, MultiModalHashes,
+                                        PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -279,24 +279,26 @@ def _cached_apply_hf_processor(
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         # The processor logic is different for len(images) <= 2 vs > 2
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only
         # perform caching for the most common case
         if mm_data_items.get_count("image", strict=False) > 2:
-            # This code path corresponds to the cache being disabled
-            return self._apply_hf_processor_main(
+            return self._apply_hf_processor(
                 prompt=prompt,
-                mm_items=mm_data_items,
+                mm_data_items=mm_data_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                enable_hf_prompt_update=True,
+                return_mm_hashes=return_mm_hashes,
             )
 
         return super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            return_mm_hashes=return_mm_hashes,
         )
diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py
index 15e126b0f4ce..99c226439ecb 100644
--- a/vllm/model_executor/models/h2ovl.py
+++ b/vllm/model_executor/models/h2ovl.py
@@ -19,8 +19,8 @@
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    MultiModalDataItems)
-from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
+                                        PromptUpdate, PromptUpdateDetails)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .intern_vit import InternVisionModel
@@ -488,24 +488,26 @@ def _cached_apply_hf_processor(
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         # The processor logic is different for len(images) <= 1 vs > 1
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only
         # perform caching for the most common case
         if mm_data_items.get_count("image", strict=False) > 1:
-            # This code path corresponds to the cache being disabled
-            return self._apply_hf_processor_main(
+            return self._apply_hf_processor(
                 prompt=prompt,
-                mm_items=mm_data_items,
+                mm_data_items=mm_data_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                enable_hf_prompt_update=True,
+                return_mm_hashes=return_mm_hashes,
             )
 
         return super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            return_mm_hashes=return_mm_hashes,
         )
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 8862b2679f93..16f5327ee79b 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor(
     dummy_inputs: BaseDummyInputsBuilder[_I],
     *,
     cache: Optional[ProcessingCache] = None,
-    enable_sanity_checks: bool = True,
 ) -> BaseMultiModalProcessor:
     if isinstance(info, PixtralHFProcessingInfo):
         return PixtralHFMultiModalProcessor(
             info,
             dummy_inputs,  # type: ignore
             cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
         )
 
     if isinstance(info, LlavaProcessingInfo):
@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor(
             info,
             dummy_inputs,  # type: ignore
             cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
         )
 
     raise NotImplementedError(type(info))
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index f8e9e3181367..12c87dc0f2af 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -312,14 +312,12 @@ def _build_mistral3_processor(
     dummy_inputs: BaseDummyInputsBuilder[_I],
     *,
     cache: Optional[ProcessingCache] = None,
-    enable_sanity_checks: bool = True,
 ) -> BaseMultiModalProcessor:
     assert isinstance(info, Mistral3ProcessingInfo)
     return Mistral3MultiModalProcessor(
         info,
         dummy_inputs,  # type: ignore
         cache=cache,
-        enable_sanity_checks=enable_sanity_checks,
     )
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 73fd80146955..d756b3b8a7ca 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -36,8 +36,9 @@
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        BaseProcessingInfo, MultiModalHashes,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import (MistralTokenizer,
@@ -271,15 +272,19 @@ def _cached_apply_hf_processor(
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
-        prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
+        prompt_ids, mm_kwargs, mm_hashes, _ = super(
+        )._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            return_mm_hashes=return_mm_hashes,
         )
 
         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, True
+        return prompt_ids, mm_kwargs, mm_hashes, True
 
 
 @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 87131122e6f2..d6ba8f1bcffe 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -876,6 +876,16 @@ def find_mm_placeholders(
 _V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
 
 
+class ProcessingCacheOptionalItem(NamedTuple):
+    key: str
+    value: Optional[MultiModalKwargsItem]
+
+
+class ProcessingCacheItem(NamedTuple):
+    key: str
+    value: MultiModalKwargsItem
+
+
 class ProcessingCache:
 
     @staticmethod
@@ -980,6 +990,22 @@ def get(
 
         return self._cache.get(cache_key)
 
+    def get_item(
+        self,
+        model_id: str,
+        modality: str,
+        input_item: object,
+        input_kwargs: Mapping[str, object],
+    ) -> ProcessingCacheOptionalItem:
+        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
+                                                 **{modality: input_item},
+                                                 **input_kwargs)
+
+        return ProcessingCacheOptionalItem(
+            key=cache_key,
+            value=self._cache.get(cache_key),
+        )
+
     def put(
         self,
         model_id: str,
@@ -997,6 +1023,9 @@ def put(
                                                  **input_kwargs)
         self._cache[cache_key] = output_kwargs
 
+    def put_item(self, item: ProcessingCacheItem) -> None:
+        self._cache[item.key] = item.value
+
 
 class BaseProcessingInfo:
     """Base class to provide the information necessary for data processing."""
@@ -1052,6 +1081,11 @@ def get_allowed_mm_limits(self) -> Mapping[str, int]:
 
 _I = TypeVar("_I", bound=BaseProcessingInfo)
 
+MultiModalHashes = dict[str, list[str]]
+"""
+A collection of hashes with a similar structure as :class:`MultiModalKwargs`.
+"""
+
 
 class BaseMultiModalProcessor(ABC, Generic[_I]):
     """
@@ -1064,14 +1098,12 @@ def __init__(self,
                  info: _I,
                  dummy_inputs: "BaseDummyInputsBuilder[_I]",
                  *,
-                 cache: Optional[ProcessingCache] = None,
-                 enable_sanity_checks: bool = True) -> None:
+                 cache: Optional[ProcessingCache] = None) -> None:
         super().__init__()
 
         self.info = info
         self.dummy_inputs = dummy_inputs
         self.cache = cache
-        self.enable_sanity_checks = enable_sanity_checks
 
         self.data_parser = self._get_data_parser()
@@ -1340,46 +1372,144 @@ def _apply_hf_processor_main(
 
         return prompt_ids, mm_kwargs, False
 
+    def _get_cache_missing_items(
+        self,
+        cache: ProcessingCache,
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
+            str, list[object]]]:
+        model_id = self.info.model_id
+
+        mm_cache_items = {
+            modality: [
+                cache.get_item(model_id, modality, item,
+                               hf_processor_mm_kwargs) for item in items
+            ]
+            for modality, items in mm_data_items.items()
+        }
+
+        mm_missing_idxs = {
+            modality: [
+                idx for idx, item in enumerate(cache_items)
+                if item.value is None
+            ]
+            for modality, cache_items in mm_cache_items.items()
+        }
+        mm_missing_data = {
+            modality: [mm_data_items[modality][idx] for idx in idxs]
+            for modality, idxs in mm_missing_idxs.items()
+        }
+
+        return mm_cache_items, mm_missing_data
+
+    def _hash_mm_items(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalHashes:
+        """Create MM hashes to be returned (only used in V1)."""
+        model_id = self.info.model_id
+
+        return {
+            modality: [
+                MultiModalHasher.hash_kwargs(model_id=model_id,
+                                             **{modality: item},
+                                             **hf_processor_mm_kwargs)
+                for item in items
+            ]
+            for modality, items in mm_items.items()
+        }
+
+    def _merge_mm_kwargs(
+        self,
+        cache: ProcessingCache,
+        mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
+        mm_missing_data: dict[str, list[object]],
+        mm_missing_kwargs: MultiModalKwargs,
+    ) -> dict[str, list[ProcessingCacheItem]]:
+        mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}
+
+        merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
+        for modality, cache_items in mm_cache_items.items():
+            for cache_item in cache_items:
+                if cache_item.value is None:
+                    kw_item = mm_missing_kwargs.get_item(
+                        modality,
+                        mm_missing_next_idx[modality],
+                    )
+                    cache_item_new = ProcessingCacheItem(
+                        key=cache_item.key,
+                        value=kw_item,
+                    )
+
+                    cache.put_item(cache_item_new)
+                    mm_missing_next_idx[modality] += 1
+                else:
+                    cache_item_new = ProcessingCacheItem(
+                        key=cache_item.key,
+                        value=cache_item.value,
+                    )
+
+                merged_items[modality].append(cache_item_new)
+
+        return dict(merged_items)
+
+    def _apply_hf_processor(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
+        (
+            prompt_ids,
+            mm_kwargs,
+            is_update_applied,
+        ) = self._apply_hf_processor_main(
+            prompt=prompt,
+            mm_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            enable_hf_prompt_update=True,
+        )
+
+        mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs)
+                     if return_mm_hashes else None)
+
+        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
+
     def _cached_apply_hf_processor(
         self,
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         """
         Apply the HF processor on the full prompt text,
         caching the results and reusing cached results.
         """
         cache = self.cache
-        model_id = self.info.model_id
 
         _, passthrough_data = self._get_hf_mm_data(mm_data_items)
         if cache is None or passthrough_data:
-            return self._apply_hf_processor_main(
+            return self._apply_hf_processor(
                 prompt=prompt,
-                mm_items=mm_data_items,
+                mm_data_items=mm_data_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                enable_hf_prompt_update=True,
+                return_mm_hashes=return_mm_hashes,
             )
 
-        mm_maybe_cached_kw_items = {
-            modality: [
-                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
-                for item in items
-            ]
-            for modality, items in mm_data_items.items()
-        }
-
-        mm_missing_idxs = {
-            modality:
-            [idx for idx, item in enumerate(kw_items) if item is None]
-            for modality, kw_items in mm_maybe_cached_kw_items.items()
-        }
-        mm_missing_data = {
-            modality: [mm_data_items[modality][idx] for idx in idxs]
-            for modality, idxs in mm_missing_idxs.items()
-        }
-        mm_missing_data_items = self._to_mm_items(mm_missing_data)
+        (
+            mm_cache_items,
+            mm_missing_data,
+        ) = self._get_cache_missing_items(
+            cache=cache,
+            mm_data_items=mm_data_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+        )
 
         # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
         # so we can't apply prompt updates until the new multimodal
@@ -1390,48 +1520,29 @@ def _cached_apply_hf_processor(
             is_update_applied,
         ) = self._apply_hf_processor_main(
             prompt=prompt,
-            mm_items=mm_missing_data_items,
+            mm_items=self._to_mm_items(mm_missing_data),
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             enable_hf_prompt_update=False,
         )
 
-        mm_missing_next_idx = {
-            modality: 0
-            for modality in mm_missing_data_items
-        }
-
-        merged_kw_items = list[MultiModalKwargsItem]()
-        for modality, kw_items in mm_maybe_cached_kw_items.items():
-            for idx, kw_item in enumerate(kw_items):
-                if kw_item is None:
-                    kw_item = mm_missing_kwargs.get_item(
-                        modality,
-                        mm_missing_next_idx[modality],
-                    )
-
-                    cache.put(
-                        model_id,
-                        modality,
-                        mm_data_items[modality][idx],
-                        hf_processor_mm_kwargs,
-                        kw_item,
-                    )
-
-                    mm_missing_next_idx[modality] += 1
-
-                merged_kw_items.append(kw_item)
+        mm_cache_items_merged = self._merge_mm_kwargs(
+            cache,
+            mm_cache_items=mm_cache_items,
+            mm_missing_data=mm_missing_data,
+            mm_missing_kwargs=mm_missing_kwargs,
+        )
 
-        if self.enable_sanity_checks:
-            mm_missing_counts = mm_missing_data_items.get_all_counts()
-            assert all(
-                item_count == mm_missing_counts[modality]
-                for modality, item_count in mm_missing_next_idx.items()), dict(
-                    mm_missing_next_idx=mm_missing_next_idx,
-                    mm_missing_counts=mm_missing_counts)
+        mm_kwargs = MultiModalKwargs.from_items([
+            item.value for cache_items in mm_cache_items_merged.values()
+            for item in cache_items
+        ])
 
-        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
+        mm_hashes = {
+            modality: [item.key for item in cache_items]
+            for modality, cache_items in mm_cache_items_merged.items()
+        } if return_mm_hashes else None
 
-        return prompt_ids, mm_kwargs, is_update_applied
+        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
 
     def _bind_and_group_updates(
         self,
@@ -1569,27 +1680,6 @@ def _validate_mm_placeholders(
                 "model (usually arising from an inconsistency between "
                 "`_call_hf_processor` and `_get_prompt_updates`).")
 
-    def _hash_mm_items(
-        self,
-        mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> dict[str, list[str]]:
-        """Create MM hashes to be returned (only used in V1)."""
-
-        # TODO: Use these hash keys for caching operations in apply_hf_processor
-        # instead of rehashing.
-        model_id = self.info.model_id
-
-        return {
-            modality: [
-                MultiModalHasher.hash_kwargs(model_id=model_id,
-                                             **{modality: item},
-                                             **hf_processor_mm_kwargs)
-                for item in items
-            ]
-            for modality, items in mm_items.items()
-        }
-
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
@@ -1655,17 +1745,16 @@ def apply(
         """
         mm_items = self._to_mm_items(mm_data)
 
-        mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
-                     if return_mm_hashes else None)
-
         (
             prompt_ids,
             mm_kwargs,
+            mm_hashes,
             is_update_applied,
         ) = self._cached_apply_hf_processor(
            prompt,
             mm_items,
             hf_processor_mm_kwargs,
+            return_mm_hashes=return_mm_hashes,
         )
 
         prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
@@ -1717,28 +1806,12 @@ def create_decoder_prompt(
         """Create input prompt for the decoder."""
         return prompt
 
-    def apply(
+    def _get_enc_dec_inputs(
         self,
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        return_mm_hashes: bool = False,
-    ) -> MultiModalEncDecInputs:
-        """
-        Process multi-modal inputs to be used in vLLM.
-        The main processing steps are modified to fit encoder-decoder model:
-        1. Create encoder prompt from input prompt text.
-        2. Apply the HF processor on encoder prompt.
-        3. Copy the input prompt text as decoder prompt inputs.
-        """
-        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
-        encoder_inputs = super().apply(
-            encoder_prompt,
-            mm_data,
-            hf_processor_mm_kwargs,
-            return_mm_hashes,
-        )
-
+        encoder_inputs: MultiModalInputs,
+    ):
         tokenizer = self.info.get_tokenizer()
         decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
         if isinstance(decoder_prompt, str):
@@ -1758,3 +1831,31 @@ def apply(
                 "prompt_token_ids": decoder_prompt_ids
             })
         return mm_inputs
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
+    ) -> MultiModalEncDecInputs:
+        """
+        Process multi-modal inputs to be used in vLLM.
+        The main processing steps are modified to fit encoder-decoder model:
+        1. Create encoder prompt from input prompt text.
+        2. Apply the HF processor on encoder prompt.
+        3. Copy the input prompt text as decoder prompt inputs.
+        """
+        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
+        encoder_inputs = super().apply(
+            encoder_prompt,
+            mm_data,
+            hf_processor_mm_kwargs,
+            return_mm_hashes,
+        )
+
+        return self._get_enc_dec_inputs(
+            prompt=prompt,
+            mm_data=mm_data,
+            encoder_inputs=encoder_inputs,
+        )
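Note for reviewers (not part of the patch): a minimal sketch of the cache round-trip that the new ProcessingCache.get_item / put_item API enables. The names cache, model_id, image, hf_processor_mm_kwargs, and kw_item are placeholders assumed to exist in the caller, not code from this diff; kw_item stands for the MultiModalKwargsItem produced by the HF processor on a cache miss.

    # Sketch only: the hash key returned by get_item() is reused both for
    # put_item() and as the entry reported back in MultiModalHashes.
    opt_item = cache.get_item(model_id, "image", image, hf_processor_mm_kwargs)

    if opt_item.value is None:
        # Cache miss: store the freshly computed kwargs under the same key.
        cache.put_item(ProcessingCacheItem(key=opt_item.key, value=kw_item))
        merged_value = kw_item
    else:
        # Cache hit: reuse the stored kwargs directly.
        merged_value = opt_item.value

    # The key doubles as the per-item hash, which is how
    # _cached_apply_hf_processor can now return mm_hashes without rehashing
    # when return_mm_hashes=True.
    mm_hashes: MultiModalHashes = {"image": [opt_item.key]}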