
Commit bccab82

Convert EngineCoreRequest to Request before reaching the engine core thread
Signed-off-by: Jialin Ouyang <[email protected]>
1 parent: e7b2042

File tree

2 files changed: +31 -21 lines

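Note: to make the intent of the commit concrete before the diffs, here is a minimal, illustrative sketch of the producer/consumer split it relies on. The input socket thread now decodes and converts each EngineCoreRequest into a Request before enqueuing it, so the conversion overlaps with whatever the core busy loop is doing (scheduling, model forward). Everything below is a toy stand-in except the names EngineCoreRequest, Request, and from_engine_core_request, which come from the diff; the sleep is an assumed stand-in for nontrivial preparation cost.

# Illustrative sketch only: why moving the EngineCoreRequest -> Request
# conversion onto the input socket thread takes it off the critical path.
import queue
import threading
import time

class EngineCoreRequest:  # toy stand-in for the serialized wire request
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

class Request:  # toy stand-in for the scheduler-ready request
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

    @classmethod
    def from_engine_core_request(cls, req: EngineCoreRequest) -> "Request":
        time.sleep(0.01)  # stands in for nontrivial preparation work
        return cls(req.request_id)

input_queue: queue.Queue = queue.Queue()

def input_socket_thread(raw_requests) -> None:
    # After this commit: decode AND convert here, in parallel with the
    # busy loop, instead of on the busy loop itself.
    for raw in raw_requests:
        input_queue.put(Request.from_engine_core_request(raw))
    input_queue.put(None)  # sentinel: no more requests

def core_busy_loop() -> None:
    # The busy loop only dequeues ready Request objects; the time it
    # previously spent converting is freed for scheduling and forward steps.
    while (req := input_queue.get()) is not None:
        print(f"scheduling {req.request_id}")

t = threading.Thread(target=input_socket_thread,
                     args=([EngineCoreRequest(f"req-{i}") for i in range(3)],))
t.start()
core_busy_loop()
t.join()

The design point is simply that the two threads pipeline: while the core loop runs one forward step, the input thread prepares the next requests.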

vllm/v1/engine/core.py

Lines changed: 28 additions & 21 deletions
@@ -194,7 +194,7 @@ def _initialize_kv_caches(
                      "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

-    def add_request(self, request: EngineCoreRequest):
+    def add_request(self, request: Request):
         """Add request to the scheduler."""
         if pooling_params := request.pooling_params:
             supported_pooling_tasks = (
@@ -203,27 +203,16 @@ def add_request(self, request: EngineCoreRequest):
             raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                              f"Supported tasks: {supported_pooling_tasks}")

-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
+        if request.use_structured_output:
             # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
+            self.structured_output_manager.grammar_init(request)

-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")

-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)

     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -766,10 +755,11 @@ def process_input_sockets(self, input_addresses: list[str],
                         bytes(type_frame.buffer))

                     # Deserialize the request data.
-                    decoder = add_request_decoder if (
-                        request_type
-                        == EngineCoreRequestType.ADD) else generic_decoder
-                    request = decoder.decode(data_frames)
+                    if request_type == EngineCoreRequestType.ADD:
+                        request = add_request_decoder.decode(data_frames)
+                        request = self._post_process_add_request(request)
+                    else:
+                        request = generic_decoder.decode(data_frames)

                     # Push to input queue for core busy loop.
                     self.input_queue.put_nowait((request_type, request))
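The rewritten branch above replaces the one-shot decoder selection with an explicit dispatch, because ADD requests now need a post-processing step after decoding. A hedged sketch of that shape follows; the enum values, JSON wire format, and helper names are toy assumptions standing in for vLLM's actual decoders.

# Illustrative sketch of type-dispatched decoding: ADD requests get their
# own decoder plus post-processing; everything else stays generic.
import json
from enum import Enum

class EngineCoreRequestType(Enum):
    ADD = "add"
    ABORT = "abort"

def decode_add(data: bytes) -> dict:
    return json.loads(data)  # stand-in for add_request_decoder.decode(...)

def decode_generic(data: bytes) -> dict:
    return json.loads(data)  # stand-in for generic_decoder.decode(...)

def post_process_add(request: dict) -> dict:
    # Stand-in for _post_process_add_request: resolve caches and build the
    # scheduler-ready object before the busy loop ever sees it.
    request["prepared"] = True
    return request

def handle_frames(request_type: EngineCoreRequestType, data: bytes) -> dict:
    # Mirrors the diff: ADD decodes then post-processes; other request
    # types keep the generic decode path unchanged.
    if request_type == EngineCoreRequestType.ADD:
        request = decode_add(data)
        request = post_process_add(request)
    else:
        request = decode_generic(data)
    return request

print(handle_frames(EngineCoreRequestType.ADD, b'{"request_id": "r1"}'))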
@@ -835,6 +825,23 @@ def process_output_sockets(self, output_paths: list[str],
                     # Limit the number of buffers to reuse.
                     reuse_buffers.append(buffer)

+    def _post_process_add_request(self, request: EngineCoreRequest) -> Request:
+        """Post-process the request before it reaches the EngineCore.
+
+        This call runs in parallel with the model forward pass, moving
+        request preparation work off the critical path."""
+        if request.mm_hashes is not None:
+            # Here, if hash exists for a multimodal input, then it will be
+            # fetched from the cache, else it will be added to the cache.
+            # Note that the cache here is mirrored with the client cache, so
+            # anything that has a hash must have a HIT cache entry here
+            # as well.
+            assert request.mm_inputs is not None
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+
+        return Request.from_engine_core_request(request)
+

 class DPEngineCoreProc(EngineCoreProc):
     """ZMQ-wrapper for running EngineCore in background process
@@ -915,7 +922,7 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)

-    def add_request(self, request: EngineCoreRequest):
+    def add_request(self, request: Request):
         if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave
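The _post_process_add_request hunk moves the multimodal cache lookup onto the input thread but keeps its invariant: the server-side cache mirrors the client-side cache, so any input arriving as a bare hash (payload omitted) must already be a hit on the server. A toy sketch of that contract follows, assuming None marks an omitted payload; the class and method name get_and_update are stand-ins for vLLM's get_and_update_p1 machinery, not its implementation.

# Illustrative sketch of a mirrored multimodal-input cache. None entries in
# mm_inputs mark cache hits whose payload was omitted on the wire; the
# server cache must already hold them and fills them back in.
class ToyMMInputCacheServer:
    """Toy model of the server-side half of a mirrored cache."""

    def __init__(self) -> None:
        self._cache: dict[str, object] = {}

    def get_and_update(self, mm_inputs: list, mm_hashes: list) -> list:
        full_inputs = []
        for inp, mm_hash in zip(mm_inputs, mm_hashes):
            if inp is None:
                # Payload omitted: the client saw a cache hit, so the
                # mirrored server cache must hold the entry as well.
                assert mm_hash in self._cache, "mirrored caches out of sync"
                full_inputs.append(self._cache[mm_hash])
            else:
                # New payload: store it so later hits can be resolved here.
                self._cache[mm_hash] = inp
                full_inputs.append(inp)
        return full_inputs

cache = ToyMMInputCacheServer()
print(cache.get_and_update([{"pixels": [1, 2, 3]}], ["h1"]))  # miss: stored
print(cache.get_and_update([None], ["h1"]))                   # hit: filled in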

vllm/v1/request.py

Lines changed: 3 additions & 0 deletions
@@ -35,10 +35,12 @@ def __init__(
         lora_request: Optional["LoRARequest"] = None,
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
+        current_wave: int = 0,
         priority: int = 0,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
+        self.current_wave = current_wave
         self.priority = priority
         self.sampling_params = sampling_params
         self.pooling_params = pooling_params
@@ -131,6 +133,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
                 sampling_params=request.sampling_params) \
                 if request.sampling_params else None,
             cache_salt=request.cache_salt,
+            current_wave=request.current_wave,
             priority=request.priority,
         )
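This request.py change exists because DPEngineCoreProc.add_request (last hunk of core.py above) now receives a Request rather than an EngineCoreRequest, yet still reads request.current_wave for its data-parallel wave bookkeeping; the counter therefore has to live on Request and be copied over in from_engine_core_request. A minimal sketch of that bookkeeping, with toy classes reduced to the comparison shown in the diff:

# Illustrative sketch: with add_request now taking Request, the wave
# counter must ride on Request itself.
class Request:  # toy stand-in
    def __init__(self, request_id: str, current_wave: int = 0) -> None:
        self.request_id = request_id
        self.current_wave = current_wave

class ToyDPEngineCore:  # toy stand-in for DPEngineCoreProc
    def __init__(self) -> None:
        self.current_wave = 0

    def add_request(self, request: Request) -> None:
        if request.current_wave > self.current_wave:
            # A newer wave started elsewhere; catch up locally.
            self.current_wave = request.current_wave
        print(f"{request.request_id}: engine wave={self.current_wave}")

core = ToyDPEngineCore()
core.add_request(Request("r1", current_wave=0))
core.add_request(Request("r2", current_wave=2))  # advances local wave to 2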
