diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index ca636bf5a6f7..0ec5e2e64425 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -23,9 +23,10 @@
 from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
+from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
+from vllm.utils import is_list_of, make_zmq_socket, resolve_obj_by_qualname
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface
@@ -194,7 +195,9 @@ def _initialize_kv_caches(
                      "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

-    def add_request(self, request: EngineCoreRequest):
+    def add_request(self, request: Union[EngineCoreRequest, Request]):
+        if type(request) is EngineCoreRequest:
+            request = self._preprocess_add_request(request)
         """Add request to the scheduler."""
         if pooling_params := request.pooling_params:
             supported_pooling_tasks = (
@@ -203,27 +206,30 @@ def add_request(self, request: EngineCoreRequest):
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")

-        if request.mm_hashes is not None:
+        if request.mm_hashes:
             # Here, if hash exists for a multimodal input, then it will be
             # fetched from the cache, else it will be added to the cache.
             # Note that the cache here is mirrored with the client cache, so
             # anything that has a hash must have a HIT cache entry here
             # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+            assert request.mm_inputs
+            updated_mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)
+            assert isinstance(updated_mm_inputs, list)
+            assert is_list_of(updated_mm_inputs, MultiModalKwargs), (
+                "Invalid updated mm_inputs in EngineCore.add_request")
+            request.mm_inputs = updated_mm_inputs

-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
+        if request.use_structured_output:
             # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
+            self.structured_output_manager.grammar_init(request)

-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")

-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)

     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -385,6 +391,13 @@ def save_tensorized_model(
         self.model_executor.save_tensorized_model(
             tensorizer_config=tensorizer_config, )

+    def _preprocess_add_request(self, request: EngineCoreRequest) -> Request:
+        """Preprocess the request.
+
+        This function can run in the input processing thread, so request
+        initialization happens in parallel with the model forward pass."""
+        return Request.from_engine_core_request(request)
+

 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -766,10 +779,11 @@ def process_input_sockets(self, input_addresses: list[str],
                         bytes(type_frame.buffer))

                     # Deserialize the request data.
-                    decoder = add_request_decoder if (
-                        request_type
-                        == EngineCoreRequestType.ADD) else generic_decoder
-                    request = decoder.decode(data_frames)
+                    if request_type == EngineCoreRequestType.ADD:
+                        request = add_request_decoder.decode(data_frames)
+                        request = self._preprocess_add_request(request)
+                    else:
+                        request = generic_decoder.decode(data_frames)

                     # Push to input queue for core busy loop.
                     self.input_queue.put_nowait((request_type, request))
@@ -915,7 +929,7 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)

-    def add_request(self, request: EngineCoreRequest):
+    def add_request(self, request: Union[EngineCoreRequest, Request]):
         if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 2ebb76a97ebe..8b6c8ceb48cd 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -32,6 +32,7 @@
 from vllm.v1.engine.utils import (CoreEngineActorManager,
                                   CoreEngineProcManager, launch_core_engines)
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.request import Request
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr

 logger = init_logger(__name__)
@@ -104,7 +105,7 @@ def shutdown(self):
     def get_output(self) -> EngineCoreOutputs:
         raise NotImplementedError

-    def add_request(self, request: EngineCoreRequest) -> None:
+    def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
         raise NotImplementedError

     def profile(self, is_start: bool = True) -> None:
@@ -238,7 +239,7 @@ def get_output(self) -> EngineCoreOutputs:
         outputs, _ = self.engine_core.step()
         return outputs.get(0) or EngineCoreOutputs()

-    def add_request(self, request: EngineCoreRequest) -> None:
+    def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
         self.engine_core.add_request(request)

     def abort_requests(self, request_ids: list[str]) -> None:
@@ -603,7 +604,7 @@ def call_utility(self, method: str, *args) -> Any:

         return future.result()

-    def add_request(self, request: EngineCoreRequest) -> None:
+    def add_request(self, request: Union[EngineCoreRequest, Request]) -> None:
         if self.is_dp:
             self.engines_running = True
         self._send_input(EngineCoreRequestType.ADD, request)
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 85f5dcb92eb4..43f11068c659 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -35,10 +35,12 @@ def __init__(
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
+        current_wave: int = 0,
         priority: int = 0,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
+        self.current_wave = current_wave
         self.priority = priority
         self.sampling_params = sampling_params
         self.pooling_params = pooling_params
@@ -131,6 +133,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             sampling_params=request.sampling_params) \
                 if request.sampling_params else None,
             cache_salt=request.cache_salt,
+            current_wave=request.current_wave,
             priority=request.priority,
         )
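
The change above routes ADD requests through _preprocess_add_request inside process_input_sockets, so Request construction happens in the input IO thread and the core busy loop only dequeues scheduler-ready objects. Below is a minimal, self-contained sketch of that producer/consumer pattern under assumed stand-in names (RawRequest, PreparedRequest, preprocess); these are hypothetical and are not vLLM's EngineCoreRequest/Request APIs.

import queue
import threading
from dataclasses import dataclass


# Hypothetical stand-in for the decoded wire-format request.
@dataclass
class RawRequest:
    request_id: str
    prompt: str


# Hypothetical stand-in for the scheduler-ready request object.
@dataclass
class PreparedRequest:
    request_id: str
    token_ids: list[int]


def preprocess(raw: RawRequest) -> PreparedRequest:
    # Per-request setup (the analogue of Request.from_engine_core_request)
    # runs here, off the busy loop.
    return PreparedRequest(raw.request_id, [ord(c) for c in raw.prompt])


input_queue: "queue.Queue[PreparedRequest]" = queue.Queue()


def input_thread(raw_requests: list[RawRequest]) -> None:
    # Analogue of process_input_sockets(): decode, preprocess, then enqueue.
    for raw in raw_requests:
        input_queue.put(preprocess(raw))


def busy_loop(num_requests: int) -> None:
    # Analogue of the core busy loop: it only dequeues requests that are
    # already ready to hand to the scheduler.
    for _ in range(num_requests):
        req = input_queue.get()
        print(f"scheduling {req.request_id} ({len(req.token_ids)} tokens)")


requests = [RawRequest("r1", "hello"), RawRequest("r2", "world")]
producer = threading.Thread(target=input_thread, args=(requests,))
producer.start()
busy_loop(len(requests))
producer.join()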