@@ -194,7 +194,7 @@ def _initialize_kv_caches(
                      "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
 
-    def add_request(self, request: EngineCoreRequest):
+    def add_request(self, request: Request):
         """Add request to the scheduler."""
         if pooling_params := request.pooling_params:
             supported_pooling_tasks = (
@@ -203,27 +203,16 @@ def add_request(self, request: EngineCoreRequest):
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")
 
-        if request.mm_hashes is not None:
-            # Here, if hash exists for a multimodal input, then it will be
-            # fetched from the cache, else it will be added to the cache.
-            # Note that the cache here is mirrored with the client cache, so
-            # anything that has a hash must have a HIT cache entry here
-            # as well.
-            assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
-                request.mm_inputs, request.mm_hashes)
-
-        req = Request.from_engine_core_request(request)
-        if req.use_structured_output:
+        if request.use_structured_output:
             # Start grammar compilation asynchronously
-            self.structured_output_manager.grammar_init(req)
+            self.structured_output_manager.grammar_init(request)
 
-        if req.kv_transfer_params is not None and (
+        if request.kv_transfer_params is not None and (
                 not self.scheduler.get_kv_connector()):
             logger.warning("Got kv_transfer_params, but no KVConnector found. "
                            "Disabling KVTransfer for this request.")
 
-        self.scheduler.add_request(req)
+        self.scheduler.add_request(request)
 
     def abort_requests(self, request_ids: list[str]):
         """Abort requests from the scheduler."""
@@ -766,10 +755,11 @@ def process_input_sockets(self, input_addresses: list[str],
                         bytes(type_frame.buffer))
 
                     # Deserialize the request data.
-                    decoder = add_request_decoder if (
-                        request_type
-                        == EngineCoreRequestType.ADD) else generic_decoder
-                    request = decoder.decode(data_frames)
+                    if request_type == EngineCoreRequestType.ADD:
+                        request = add_request_decoder.decode(data_frames)
+                        request = self._post_process_add_request(request)
+                    else:
+                        request = generic_decoder.decode(data_frames)
 
                     # Push to input queue for core busy loop.
                     self.input_queue.put_nowait((request_type, request))
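The decode path now branches on the request type so that ADD requests are post-processed on the input socket IO thread before they are queued for the busy loop. A self-contained toy sketch of that pattern (none of these names are vLLM APIs) is:

```python
# Toy sketch: finish per-request preparation on the IO thread so the busy
# loop only dequeues ready objects and stays on the critical path.
import queue
import threading

input_queue: "queue.Queue[tuple[str, dict]]" = queue.Queue()


def prepare_add(payload: dict) -> dict:
    # Stand-in for _post_process_add_request(): cache lookups, building the
    # scheduler-facing object, etc.
    return {**payload, "prepared": True}


def io_thread(raw_messages: list) -> None:
    for request_type, payload in raw_messages:
        request = prepare_add(payload) if request_type == "ADD" else payload
        input_queue.put_nowait((request_type, request))


def busy_loop(expected: int) -> None:
    for _ in range(expected):
        request_type, request = input_queue.get()
        # No preparation work here; just schedule/execute.
        print(request_type, request)


messages = [("ADD", {"id": 1}), ("ABORT", {"id": 1})]
producer = threading.Thread(target=io_thread, args=(messages,))
producer.start()
busy_loop(len(messages))
producer.join()
```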
@@ -835,6 +825,23 @@ def process_output_sockets(self, output_paths: list[str],
                     # Limit the number of buffers to reuse.
                     reuse_buffers.append(buffer)
 
+    def _post_process_add_request(self, request: EngineCoreRequest) -> Request:
+        """Post-process the request before it reaches the EngineCore.
+
+        This runs in parallel with the model forward pass, moving request
+        preparation work off the critical path."""
+        if request.mm_hashes is not None:
+            # Here, if hash exists for a multimodal input, then it will be
+            # fetched from the cache, else it will be added to the cache.
+            # Note that the cache here is mirrored with the client cache, so
+            # anything that has a hash must have a HIT cache entry here
+            # as well.
+            assert request.mm_inputs is not None
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
+                request.mm_inputs, request.mm_hashes)
+
+        return Request.from_engine_core_request(request)
+
 
 class DPEngineCoreProc(EngineCoreProc):
     """ZMQ-wrapper for running EngineCore in background process
@@ -915,7 +922,7 @@ def shutdown(self):
         if dp_group := getattr(self, "dp_group", None):
             stateless_destroy_torch_distributed_process_group(dp_group)
 
-    def add_request(self, request: EngineCoreRequest):
+    def add_request(self, request: Request):
         if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave