3 changes: 3 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -1232,6 +1232,9 @@ async def start_profile(self) -> None:
async def stop_profile(self) -> None:
self.engine.stop_profile()

async def reset_mm_cache(self) -> None:
self.engine.reset_mm_cache()

async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
self.engine.reset_prefix_cache(device)
7 changes: 7 additions & 0 deletions vllm/engine/llm_engine.py
@@ -409,6 +409,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
# the next step without re-scheduling.
self._skip_scheduling_next_step = False

# Don't keep the dummy data in memory
self.reset_mm_cache()

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).

@@ -913,6 +916,10 @@ def has_unfinished_requests_for_virtual_engine(
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()

def reset_mm_cache(self) -> bool:
"""Reset the multi-modal cache."""
return self.input_preprocessor.mm_registry.reset_processor_cache()

def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices."""

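With the V0 `LLMEngine.reset_mm_cache()` in place, the cache can also be cleared from user code. A hedged usage sketch (not part of this PR; the model name is only an example and assumes a multi-modal checkpoint is available):

```python
from vllm import LLM

# Build a V0 engine for a multi-modal model, then drop anything left in the
# multi-modal processing cache (e.g. dummy data from startup profiling).
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
ok = llm.llm_engine.reset_mm_cache()  # returns True on success, per the method above
```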
5 changes: 5 additions & 0 deletions vllm/engine/multiprocessing/__init__.py
@@ -123,6 +123,10 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2


class RPCResetMultiModalCacheRequest(Enum):
RESET = 1


@dataclass
class RPCResetPrefixCacheRequest:
device: Device
@@ -164,6 +168,7 @@ class RPCAdapterLoadedResponse:

RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest, RPCSleepRequest,
RPCWakeUpRequest, RPCIsSleepingRequest]

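Like `RPCUProfileRequest`, the new request carries no payload, so a one-member `Enum` is enough for the engine loop to dispatch on type. A small illustration of the sentinel (outside the diff, assuming this PR is applied):

```python
from vllm.engine.multiprocessing import RPCResetMultiModalCacheRequest

# The client sends the enum member itself; the engine process only inspects
# its type (see handle_new_input in engine.py below).
request = RPCResetMultiModalCacheRequest.RESET
assert isinstance(request, RPCResetMultiModalCacheRequest)
```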
8 changes: 8 additions & 0 deletions vllm/engine/multiprocessing/client.py
@@ -31,6 +31,7 @@
RPCIsSleepingResponse,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
@@ -687,6 +688,13 @@ async def stop_profile(self) -> None:
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)

async def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache"""

await self._send_one_way_rpc_request(
request=RPCResetMultiModalCacheRequest.RESET,
socket=self.input_socket)

async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
"""Reset the prefix cache"""
6 changes: 6 additions & 0 deletions vllm/engine/multiprocessing/engine.py
@@ -22,6 +22,7 @@
RPCIsSleepingResponse,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetMultiModalCacheRequest,
RPCResetPrefixCacheRequest,
RPCSleepRequest, RPCStartupRequest,
RPCStartupResponse,
@@ -269,6 +270,8 @@ def handle_new_input(self):
self.stop_profile()
elif isinstance(request, RPCLoadAdapterRequest):
self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetMultiModalCacheRequest):
self.reset_mm_cache()
elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache()
elif isinstance(request, RPCSleepRequest):
@@ -409,6 +412,9 @@ def start_profile(self) -> None:
def stop_profile(self) -> None:
self.engine.stop_profile()

def reset_mm_cache(self) -> bool:
return self.engine.reset_mm_cache()

def reset_prefix_cache(self) -> bool:
return self.engine.reset_prefix_cache()

5 changes: 5 additions & 0 deletions vllm/engine/protocol.py
@@ -278,6 +278,11 @@ async def stop_profile(self) -> None:
"""Start profiling the engine"""
...

@abstractmethod
async def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache"""
...

@abstractmethod
async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -146,6 +146,10 @@ async def build_async_engine_client(

async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:

# Don't keep the dummy data in memory
await engine.reset_mm_cache()

yield engine


5 changes: 5 additions & 0 deletions vllm/multimodal/processing.py
@@ -1026,6 +1026,11 @@ def put(
def put_item(self, item: ProcessingCacheItem) -> None:
self._cache[item.key] = item.value

def reset(self) -> bool:
self._cache.clear()

return True


class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
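The new `reset` follows a deliberately small pattern: clear the backing store and report success so callers further up the stack (registry, engine, RPC client) can forward a boolean. A self-contained sketch of the same shape, using a hypothetical cache class rather than vLLM code:

```python
class TinyCache:
    """Stand-in for ProcessingCache, illustrating the clear-and-confirm pattern."""

    def __init__(self) -> None:
        self._cache: dict[str, object] = {}

    def put_item(self, key: str, value: object) -> None:
        self._cache[key] = value

    def reset(self) -> bool:
        # Same contract as ProcessingCache.reset(): empty the cache, report success.
        self._cache.clear()
        return True


cache = TinyCache()
cache.put_item("mm_item", object())
assert cache.reset() and not cache._cache
```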
14 changes: 10 additions & 4 deletions vllm/multimodal/registry.py
@@ -88,6 +88,12 @@ def __init__(self) -> None:

self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)

def reset_processor_cache(self) -> bool:
"""Reset the multi-modal processing cache."""
self._processing_cache.reset()

return True # Success

@deprecated("Legacy input processor/mapper pipeline has been removed. "
"Please update your model runner to use "
"`seq_group_metadata.multi_modal_data` directly without "
@@ -106,7 +112,7 @@ def get_max_tokens_per_item_by_modality(
if not model_config.is_multimodal_model:
return {}

processor = self.create_processor(model_config, disable_cache=True)
processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor)

seq_len = model_config.max_model_len
@@ -190,7 +196,7 @@ def get_mm_limits_per_prompt(
if not model_config.is_multimodal_model:
return {}

processor = self.create_processor(model_config, disable_cache=True)
processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()

@@ -286,7 +292,7 @@ def get_decoder_dummy_data(

The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=True)
processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

@@ -310,7 +316,7 @@ def get_encoder_dummy_data(

The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=True)
processor = self.create_processor(model_config, disable_cache=False)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

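Taken together, the changes in this file flip the profiling paths from throwaway processors (`disable_cache=True`) to the shared cache, and add `reset_processor_cache()` so the engines can drop the dummy profiling entries once startup finishes. A hedged sketch of that flow (the bare registry construction and the elided profiling call are simplifying assumptions):

```python
from vllm.multimodal.registry import MultiModalRegistry

registry = MultiModalRegistry()
# During startup profiling, the methods above now call
#   registry.create_processor(model_config, disable_cache=False)
# so the dummy multi-modal inputs they build are memoized in the shared cache.
# Once profiling is done, the engine clears those entries in one call:
assert registry.reset_processor_cache() is True
```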
5 changes: 5 additions & 0 deletions vllm/v1/engine/async_llm.py
@@ -476,6 +476,11 @@ async def start_profile(self) -> None:
async def stop_profile(self) -> None:
await self.engine_core.profile_async(False)

async def reset_mm_cache(self) -> None:
self.processor.mm_registry.reset_processor_cache()
self.processor.mm_input_cache_client.reset()
await self.engine_core.reset_mm_cache_async()

async def reset_prefix_cache(self,
device: Optional[Device] = None) -> None:
if device == Device.CPU:
9 changes: 9 additions & 0 deletions vllm/v1/engine/core.py
@@ -277,6 +277,15 @@ def shutdown(self):
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)

def reset_mm_cache(self):
# NOTE: Since this is mainly for debugging, we don't attempt to
# re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
if self.scheduler.get_num_unfinished_requests():
logger.warning("Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches.")

self.mm_input_cache_server.reset()

def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()

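Given the warning above, a caller that wants to stay safe can gate the reset on the engine being idle. A hedged call-site sketch, written against the V0 `LLMEngine` API for brevity (the helper itself is hypothetical; the same caution applies to the V1 path):

```python
from vllm.engine.llm_engine import LLMEngine


def reset_mm_cache_when_idle(engine: LLMEngine) -> bool:
    """Hypothetical helper: skip the reset while requests are still in flight."""
    if engine.has_unfinished_requests():
        return False  # avoid desyncing the internal caches mid-request
    return engine.reset_mm_cache()
```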
15 changes: 15 additions & 0 deletions vllm/v1/engine/core_client.py
@@ -88,6 +88,9 @@ def add_request(self, request: EngineCoreRequest) -> None:
def profile(self, is_start: bool = True) -> None:
raise NotImplementedError

def reset_mm_cache(self) -> None:
raise NotImplementedError

def reset_prefix_cache(self) -> None:
raise NotImplementedError

@@ -143,6 +146,9 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
async def profile_async(self, is_start: bool = True) -> None:
raise NotImplementedError

async def reset_mm_cache_async(self) -> None:
raise NotImplementedError

async def reset_prefix_cache_async(self) -> None:
raise NotImplementedError

@@ -214,6 +220,9 @@ def shutdown(self) -> None:
def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start)

def reset_mm_cache(self) -> None:
self.engine_core.reset_mm_cache()

def reset_prefix_cache(self) -> None:
self.engine_core.reset_prefix_cache()

@@ -600,6 +609,9 @@ def abort_requests(self, request_ids: list[str]) -> None:
def profile(self, is_start: bool = True) -> None:
self.call_utility("profile", is_start)

def reset_mm_cache(self) -> None:
self.call_utility("reset_mm_cache")

def reset_prefix_cache(self) -> None:
self.call_utility("reset_prefix_cache")

@@ -787,6 +799,9 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
async def profile_async(self, is_start: bool = True) -> None:
await self.call_utility_async("profile", is_start)

async def reset_mm_cache_async(self) -> None:
await self.call_utility_async("reset_mm_cache")

async def reset_prefix_cache_async(self) -> None:
await self.call_utility_async("reset_prefix_cache")

8 changes: 8 additions & 0 deletions vllm/v1/engine/llm_engine.py
@@ -101,6 +101,9 @@ def __init__(
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore

# Don't keep the dummy data in memory
self.reset_mm_cache()

@classmethod
def from_vllm_config(
cls,
@@ -240,6 +243,11 @@ def start_profile(self):
def stop_profile(self):
self.engine_core.profile(False)

def reset_mm_cache(self):
self.processor.mm_registry.reset_processor_cache()
self.processor.mm_input_cache_client.reset()
self.engine_core.reset_mm_cache()

def reset_prefix_cache(self, device: Optional[Device] = None):
self.engine_core.reset_prefix_cache()

5 changes: 5 additions & 0 deletions vllm/v1/engine/mm_input_cache.py
@@ -83,3 +83,8 @@ def get_and_update_p1(
full_mm_inputs.append(mm_input)

return full_mm_inputs

def reset(self) -> bool:
self.mm_cache.clear()

return True
4 changes: 4 additions & 0 deletions vllm/v1/engine/processor.py
@@ -54,6 +54,10 @@ def __init__(
self.use_hash = self.mm_input_cache_client.use_cache or \
self.cache_config.enable_prefix_caching

@property
def mm_registry(self):
return self.input_preprocessor.mm_registry

def _validate_logprobs(
self,
params: SamplingParams,
Expand Down