1010from vllm .config import CompilationLevel , VllmConfig
1111from vllm .distributed .parallel_state import graph_capture
1212from vllm .forward_context import set_forward_context
13- from vllm .inputs import INPUT_REGISTRY , InputRegistry
13+ from vllm .inputs import INPUT_REGISTRY
1414from vllm .logger import init_logger
1515from vllm .model_executor .model_loader import get_model
16- from vllm .multimodal import MultiModalKwargs
16+ from vllm .multimodal import MULTIMODAL_REGISTRY , MultiModalKwargs
1717from vllm .sampling_params import SamplingType
1818from vllm .utils import (STR_DTYPE_TO_TORCH_DTYPE , DeviceMemoryProfiler ,
1919 LayerBlockType , cdiv , is_pin_memory_available )
2020from vllm .v1 .attention .backends .flash_attn import (FlashAttentionBackend ,
2121 FlashAttentionMetadata )
22+ from vllm .v1 .engine .mm_input_mapper import MMInputMapperClient
2223from vllm .v1 .outputs import ModelRunnerOutput
2324from vllm .v1 .sample .metadata import SamplingMetadata
2425from vllm .v1 .worker .gpu_input_batch import CachedRequestState , InputBatch
@@ -35,7 +36,6 @@ def __init__(
3536 self ,
3637 vllm_config : VllmConfig ,
3738 device : torch .device ,
38- input_registry : InputRegistry = INPUT_REGISTRY ,
3939 ):
4040 self .vllm_config = vllm_config
4141 self .model_config = vllm_config .model_config
@@ -77,7 +77,12 @@ def __init__(
7777 self .hidden_size = model_config .get_hidden_size ()
7878
7979 # Multi-modal data support
80- self .input_registry = input_registry
80+ self .input_registry = INPUT_REGISTRY
81+ self .mm_registry = MULTIMODAL_REGISTRY
82+ # NOTE: mm_input_mapper is only used for memory profiling.
83+ self .mm_input_mapper = MMInputMapperClient (self .model_config )
84+ self .max_num_encoder_input_tokens = self .scheduler_config .max_num_encoder_input_tokens # noqa: E501
85+ self .encoder_cache_size = self .scheduler_config .encoder_cache_size
8186
8287 # Lazy initialization
8388 # self.model: nn.Module # Set after load_model
@@ -599,8 +604,6 @@ def _dummy_run(
599604 return hidden_states
600605
601606 def profile_run (self ) -> None :
602- # TODO(woosuk): Profile the max memory usage of the encoder and
603- # the encoder cache.
604607 # use an empty tensor instead of `None` to force Dynamo to pass
605608 # it by reference, rather than by specializing on the value `None`.
606609 # the `dtype` argument does not matter, and we use `float32` as
@@ -612,6 +615,57 @@ def profile_run(self) -> None:
612615 torch .tensor ([], dtype = torch .float32 , device = self .device )
613616 for _ in range (self .num_attn_layers )
614617 ]
618+
619+ # Profile with multimodal encoder & encoder cache.
620+ # TODO (ywang96): generalize this beyond image modality since
621+ # mm_input_mapper only supports image inputs.
622+ if self .is_multimodal_model :
623+
624+ # Create dummy batch of multimodal inputs.
625+ dummy_request_data = self .input_registry .dummy_data_for_profiling (
626+ model_config = self .model_config ,
627+ seq_len = self .max_num_tokens ,
628+ mm_registry = self .mm_registry ,
629+ )
630+ dummy_mm_data = dummy_request_data .multi_modal_data
631+ dummy_mm_kwargs , _ = self .mm_input_mapper .process_inputs (
632+ mm_data = dummy_mm_data ,
633+ mm_hashes = None ,
634+ mm_processor_kwargs = None ,
635+ precomputed_mm_inputs = None )
636+
637+ # NOTE: Currently the model is profiled with a single non-text
638+ # modality even when it supports multiple modalities.
639+ max_tokens_per_mm_item = max (
640+ self .mm_registry .get_max_tokens_per_item_by_modality (
641+ self .model_config ).values ())
642+
643+ max_num_mm_items = min (
644+ self .max_num_encoder_input_tokens ,
645+ self .encoder_cache_size ) // max_tokens_per_mm_item
646+
647+ # Dummy data definition in V0 may contain multiple multimodal items
648+ # (e.g., multiple images) for a single request, therefore here we
649+ # always replicate the first item max_num_mm_items times since in V1
650+ # they are scheduled to be processed separately.
651+ batched_dummy_mm_inputs = MultiModalKwargs .batch (
652+ [dummy_mm_kwargs [0 ]] * max_num_mm_items )
653+ batched_dummy_mm_inputs = MultiModalKwargs .as_kwargs (
654+ batched_dummy_mm_inputs , device = self .device )
655+
656+ # Run multimodal encoder.
657+ dummy_encoder_outputs = self .model .get_multimodal_embeddings (
658+ ** batched_dummy_mm_inputs )
659+ assert len (dummy_encoder_outputs ) == max_num_mm_items , (
660+ "Expected dimension 0 of encoder outputs to match the number "
661+ f"of multimodal data items: { max_num_mm_items } , got "
662+ f"{ len (dummy_encoder_outputs )= } instead. This is most likely "
663+ "due to the 'get_multimodal_embeddings' method of the model "
664+ "not implemented correctly." )
665+
666+ # Cache the dummy encoder outputs.
667+ self .encoder_cache ["tmp" ] = dict (enumerate (dummy_encoder_outputs ))
668+
615669 # Trigger compilation for general shape.
616670 hidden_states = self ._dummy_run (self .model , self .max_num_tokens ,
617671 dummy_kv_caches )
@@ -620,6 +674,7 @@ def profile_run(self) -> None:
620674 # TODO(woosuk): Consider the memory usage of the sampler.
621675 torch .cuda .synchronize ()
622676 del hidden_states , logits
677+ self .encoder_cache .clear ()
623678 gc .collect ()
624679
625680 def capture_model (self ) -> None :
0 commit comments