7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py

@@ -409,6 +409,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             # the next step without re-scheduling.
             self._skip_scheduling_next_step = False
 
+        # Don't keep the dummy data in memory
+        self.reset_mm_cache()
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
 
@@ -913,6 +916,10 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
+    def reset_mm_cache(self) -> bool:
+        """Reset the multi-modal cache."""
+        return self.input_preprocessor.mm_registry.reset_processor_cache()
+
     def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
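A usage sketch of the new entry point, assuming a locally constructed v0 engine; the model name is illustrative, not part of this diff:

from vllm import EngineArgs, LLMEngine

# Illustrative multi-modal model; any such model exercises the cache.
engine = LLMEngine.from_engine_args(
    EngineArgs(model="llava-hf/llava-1.5-7b-hf"))

# __init__ now clears the cache once at startup (see the hunk above);
# callers can also clear it manually, e.g. between benchmark runs.
assert engine.reset_mm_cache()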
5 changes: 5 additions & 0 deletions vllm/multimodal/processing.py

@@ -1026,6 +1026,11 @@ def put(
     def put_item(self, item: ProcessingCacheItem) -> None:
         self._cache[item.key] = item.value
 
+    def reset(self) -> bool:
+        self._cache.clear()
+
+        return True
+
 
 class BaseProcessingInfo:
     """Base class to provide the information necessary for data processing."""
14 changes: 10 additions & 4 deletions vllm/multimodal/registry.py

@@ -88,6 +88,12 @@ def __init__(self) -> None:
 
         self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
 
+    def reset_processor_cache(self) -> bool:
+        """Reset the multi-modal processing cache."""
+        self._processing_cache.reset()
+
+        return True  # Success
+
     @deprecated("Legacy input processor/mapper pipeline has been removed. "
                 "Please update your model runner to use "
                 "`seq_group_metadata.multi_modal_data` directly without "

@@ -106,7 +112,7 @@ def get_max_tokens_per_item_by_modality(
         if not model_config.is_multimodal_model:
             return {}
 
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len

@@ -190,7 +196,7 @@ def get_mm_limits_per_prompt(
         if not model_config.is_multimodal_model:
             return {}
 
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
         return profiler.get_mm_limits()
 
@@ -286,7 +292,7 @@ def get_decoder_dummy_data(
 
         The model is identified by ``model_config``.
         """
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
         dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
 
@@ -310,7 +316,7 @@ def get_encoder_dummy_data(
 
         The model is identified by ``model_config``.
         """
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
         dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
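Flipping disable_cache from True to False lets the profiling and dummy-data paths write into the shared ProcessingCache instead of a throwaway one; per the startup comment "Don't keep the dummy data in memory", the explicit reset then evicts those entries. A toy sketch of that flow, with hypothetical names rather than vLLM internals:

class ToyCache:
    def __init__(self) -> None:
        self._cache: dict[str, object] = {}

    def put(self, key: str, value: object) -> None:
        self._cache[key] = value

    def reset(self) -> bool:
        self._cache.clear()
        return True

cache = ToyCache()
cache.put("dummy-image-tokens", [0] * 1024)  # profiling writes dummy data
assert cache.reset()                          # startup reset frees it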
12 changes: 12 additions & 0 deletions vllm/v1/engine/core.py

@@ -121,6 +121,9 @@ def __init__(self,
             self.batch_queue = queue.Queue(self.batch_queue_size)
         self.vllm_config = vllm_config
 
+        # Don't keep the dummy data in memory
+        self.reset_mm_cache()
+
     def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()

@@ -277,6 +280,15 @@ def shutdown(self):
     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)
 
+    def reset_mm_cache(self):
+        # NOTE: Since this is mainly for debugging, we don't attempt to
+        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
+        if self.scheduler.get_num_unfinished_requests():
+            logger.warning("Resetting the multi-modal cache when requests are "
+                           "in progress may lead to desynced internal caches.")
+
+        self.mm_input_cache_server.reset()
+
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
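An illustrative sketch of the calling discipline the warning suggests, here as a stricter refuse-instead-of-warn policy; engine_core stands for any object exposing these two attributes and is not vLLM code:

def drain_then_reset(engine_core) -> None:
    # Mirrors the guard above: resetting while requests are in flight
    # risks desyncing the P0/P1 cache mirrors, so refuse outright
    # rather than merely warning.
    if engine_core.scheduler.get_num_unfinished_requests():
        raise RuntimeError("Drain in-flight requests before resetting "
                           "the multi-modal cache.")
    engine_core.reset_mm_cache()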
9 changes: 9 additions & 0 deletions vllm/v1/engine/core_client.py

@@ -88,6 +88,9 @@ def add_request(self, request: EngineCoreRequest) -> None:
     def profile(self, is_start: bool = True) -> None:
         raise NotImplementedError
 
+    def reset_mm_cache(self) -> None:
+        raise NotImplementedError
+
     def reset_prefix_cache(self) -> None:
         raise NotImplementedError

@@ -214,6 +217,9 @@ def shutdown(self) -> None:
     def profile(self, is_start: bool = True) -> None:
         self.engine_core.profile(is_start)
 
+    def reset_mm_cache(self) -> None:
+        self.engine_core.reset_mm_cache()
+
     def reset_prefix_cache(self) -> None:
         self.engine_core.reset_prefix_cache()

@@ -600,6 +606,9 @@ def abort_requests(self, request_ids: list[str]) -> None:
     def profile(self, is_start: bool = True) -> None:
         self.call_utility("profile", is_start)
 
+    def reset_mm_cache(self) -> None:
+        self.call_utility("reset_mm_cache")
+
     def reset_prefix_cache(self) -> None:
         self.call_utility("reset_prefix_cache")
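The three additions follow the existing client hierarchy: an abstract stub, a direct call in the in-process client, and a serialized utility call in the multiprocess client. A minimal sketch of that dispatch pattern, using toy classes rather than the vLLM ones:

from typing import Callable

class ToyClient:
    def reset_mm_cache(self) -> None:
        raise NotImplementedError

class ToyInprocClient(ToyClient):
    def __init__(self, core) -> None:
        self.core = core

    def reset_mm_cache(self) -> None:
        self.core.reset_mm_cache()  # same process, direct call

class ToyMPClient(ToyClient):
    def __init__(self, call_utility: Callable[[str], None]) -> None:
        self.call_utility = call_utility

    def reset_mm_cache(self) -> None:
        self.call_utility("reset_mm_cache")  # forwarded over IPC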
5 changes: 5 additions & 0 deletions vllm/v1/engine/llm_engine.py

@@ -240,6 +240,11 @@ def start_profile(self):
     def stop_profile(self):
         self.engine_core.profile(False)
 
+    def reset_mm_cache(self):
+        self.processor.mm_registry.reset_processor_cache()
+        self.processor.mm_input_cache_client.reset()
+        self.engine_core.reset_mm_cache()
+
     def reset_prefix_cache(self, device: Optional[Device] = None):
         self.engine_core.reset_prefix_cache()
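This method clears all three caches the NOTE in core.py names: the P0 processor cache, the P0 input-cache mirror, and (via the engine core) the P1 mirror. A toy illustration, with hypothetical names, of why the mirrors must be cleared together:

# Toy mirror pair (hypothetical): P0 tracks which keys it believes P1
# holds and skips resending them, so both sides must be cleared together.
p0_mirror: set[str] = set()
p1_cache: dict[str, object] = {}

def reset_all_mm_caches() -> None:
    # Clearing only one side would leave the other trusting stale keys,
    # which is the desync the NOTE in core.py warns about.
    p0_mirror.clear()
    p1_cache.clear()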
5 changes: 5 additions & 0 deletions vllm/v1/engine/mm_input_cache.py

@@ -83,3 +83,8 @@ def get_and_update_p1(
             full_mm_inputs.append(mm_input)
 
         return full_mm_inputs
+
+    def reset(self) -> bool:
+        self.mm_cache.clear()
+
+        return True
4 changes: 4 additions & 0 deletions vllm/v1/engine/processor.py

@@ -54,6 +54,10 @@ def __init__(
         self.use_hash = self.mm_input_cache_client.use_cache or \
             self.cache_config.enable_prefix_caching
 
+    @property
+    def mm_registry(self):
+        return self.input_preprocessor.mm_registry
+
     def _validate_logprobs(
         self,
         params: SamplingParams,