88import uuid
99import weakref
1010from abc import ABC , abstractmethod
11- from collections .abc import Awaitable
11+ from collections .abc import Awaitable , Sequence
1212from concurrent .futures import Future
1313from dataclasses import dataclass , field
1414from threading import Thread
3535
3636_R = TypeVar ('_R' ) # Return type for collective_rpc
3737
38- STARTUP_POLL_PERIOD_MS = 10000
39-
4038
4139class EngineCoreClient (ABC ):
4240 """
@@ -263,13 +261,15 @@ def __init__(
263261 vllm_config : VllmConfig ,
264262 executor_class : type [Executor ],
265263 log_stats : bool ,
266- input_path : str ,
264+ ctx : Union [ zmq . Context , zmq . asyncio . Context ] ,
267265 output_path : str ,
268266 index : int = 0 ,
269267 local_dp_rank : int = 0 ,
270268 ):
271- self .index = index
272- self .identity = index .to_bytes (length = 2 , byteorder = "little" )
269+ # Paths and sockets for IPC.
270+ input_path = get_open_zmq_ipc_path ()
271+ self .input_socket = make_zmq_socket (ctx , input_path ,
272+ zmq .constants .PUSH )
273273 try :
274274 # Start EngineCore in background process.
275275 self .proc_handle = BackgroundProcHandle (
@@ -291,9 +291,14 @@ def __init__(
291291 # Ensure socket is closed if process fails to start.
292292 self .close ()
293293
294+ def send_multipart (self , msg_parts : Sequence ):
295+ return self .input_socket .send_multipart (msg_parts , copy = False )
296+
294297 def close (self ):
295298 if proc_handle := getattr (self , "proc_handle" , None ):
296299 proc_handle .shutdown ()
300+ if socket := getattr (self , "input_socket" , None ):
301+ socket .close (linger = 0 )
297302
298303
299304@dataclass
@@ -304,7 +309,6 @@ class BackgroundResources:
304309 ctx : Union [zmq .Context ]
305310 core_engines : list [CoreEngine ] = field (default_factory = list )
306311 output_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
307- input_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
308312 shutdown_path : Optional [str ] = None
309313
310314 def __call__ (self ):
@@ -317,8 +321,6 @@ def __call__(self):
317321 # aren't explicitly closed first.
318322 if self .output_socket is not None :
319323 self .output_socket .close (linger = 0 )
320- if self .input_socket is not None :
321- self .input_socket .close (linger = 0 )
322324 if self .shutdown_path is not None :
323325 # We must ensure that the sync output socket is
324326 # closed cleanly in its own thread.
@@ -385,51 +387,21 @@ def sigusr1_handler(signum, frame):
385387
386388 # Paths and sockets for IPC.
387389 self .output_path = get_open_zmq_ipc_path ()
388- input_path = get_open_zmq_ipc_path ()
389- self .input_socket = make_zmq_socket (self .ctx ,
390- input_path ,
391- zmq .ROUTER ,
392- bind = True )
393- self .resources .input_socket = self .input_socket
394390
395391 new_core_engine = lambda index , local_dp_rank = None : CoreEngine (
396- vllm_config , executor_class , log_stats , input_path , self .
397- output_path , index , local_dp_rank )
392+ vllm_config , executor_class , log_stats , self . ctx , self .output_path ,
393+ index , local_dp_rank )
398394
399395 # Start engine core process(es).
400396 self ._init_core_engines (vllm_config , new_core_engine ,
401397 self .resources .core_engines )
402398
403399 # Wait for engine core process(es) to start.
404- self ._wait_for_engine_startup ()
400+ for engine in self .resources .core_engines :
401+ engine .proc_handle .wait_for_startup ()
405402
406403 self .utility_results : dict [int , AnyFuture ] = {}
407404
408- def _wait_for_engine_startup (self ):
409- # Get a sync handle to the socket which can be sync or async.
410- sync_input_socket = zmq .Socket .shadow (self .input_socket )
411-
412- # Wait for engine core process(es) to send ready messages.
413- identities = set (eng .index for eng in self .resources .core_engines )
414- while identities :
415- while not sync_input_socket .poll (timeout = STARTUP_POLL_PERIOD_MS ):
416- logger .info ("Waiting for %d core engine proc(s) to start: %s" ,
417- len (identities ), identities )
418- eng_id_bytes , msg = sync_input_socket .recv_multipart ()
419- eng_id = int .from_bytes (eng_id_bytes , byteorder = "little" )
420- if eng_id not in identities :
421- raise RuntimeError (f"Unexpected or duplicate engine: { eng_id } " )
422- if msg != b'READY' :
423- raise RuntimeError (f"Engine { eng_id } failed: { msg .decode ()} " )
424- logger .info ("Core engine process %d ready." , eng_id )
425- identities .discard (eng_id )
426-
427- # Double check that the process are running.
428- for engine in self .resources .core_engines :
429- proc = engine .proc_handle .proc
430- if proc .exitcode is not None :
431- raise RuntimeError (f"Engine proc { proc .name } not running" )
432-
433405 def _init_core_engines (
434406 self ,
435407 vllm_config : VllmConfig ,
@@ -522,10 +494,9 @@ def get_output(self) -> EngineCoreOutputs:
522494 return self .outputs_queue .get ()
523495
524496 def _send_input (self , request_type : EngineCoreRequestType , request : Any ):
525- # (Identity, RequestType, SerializedRequest)
526- msg = (self .core_engine .identity , request_type .value ,
527- self .encoder .encode (request ))
528- self .input_socket .send_multipart (msg , copy = False )
497+ # (RequestType, SerializedRequest)
498+ msg = (request_type .value , self .encoder .encode (request ))
499+ self .core_engine .send_multipart (msg )
529500
530501 def call_utility (self , method : str , * args ) -> Any :
531502 call_id = uuid .uuid1 ().int >> 64
@@ -654,34 +625,30 @@ async def get_output_async(self) -> EngineCoreOutputs:
654625 assert self .outputs_queue is not None
655626 return await self .outputs_queue .get ()
656627
657- def _send_input (self ,
658- request_type : EngineCoreRequestType ,
659- request : Any ,
660- engine : Optional [CoreEngine ] = None ) -> Awaitable [None ]:
661- if engine is None :
662- engine = self .core_engine
663-
664- message = (request_type .value , self .encoder .encode (request ))
665- return self ._send_input_message (message , engine )
628+ async def _send_input (self , request_type : EngineCoreRequestType ,
629+ request : Any ) -> None :
630+ await self .core_engine .send_multipart (
631+ (request_type .value , self .encoder .encode (request )))
666632
667- def _send_input_message (self , message : tuple [bytes , bytes ],
668- engine : CoreEngine ) -> Awaitable [None ]:
669- message = (engine .identity , ) + message # type: ignore[assignment]
670- return self .input_socket .send_multipart (message , copy = False )
633+ self ._ensure_output_queue_task ()
671634
672635 async def call_utility_async (self , method : str , * args ) -> Any :
673636 return await self ._call_utility_async (method ,
674637 * args ,
675638 engine = self .core_engine )
676639
677- async def _call_utility_async (self , method : str , * args ,
678- engine : CoreEngine ) -> Any :
640+ async def _call_utility_async (
641+ self ,
642+ method : str ,
643+ * args ,
644+ engine : CoreEngine ,
645+ ) -> Any :
679646 call_id = uuid .uuid1 ().int >> 64
680647 future = asyncio .get_running_loop ().create_future ()
681648 self .utility_results [call_id ] = future
682649 message = (EngineCoreRequestType .UTILITY .value ,
683650 self .encoder .encode ((call_id , method , args )))
684- await self . _send_input_message (message , engine )
651+ await engine . send_multipart (message )
685652 self ._ensure_output_queue_task ()
686653 return await future
687654
@@ -690,7 +657,6 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
690657 # tokenized.
691658 request .prompt = None
692659 await self ._send_input (EngineCoreRequestType .ADD , request )
693- self ._ensure_output_queue_task ()
694660
695661 async def abort_requests_async (self , request_ids : list [str ]) -> None :
696662 if len (request_ids ) > 0 :
@@ -795,15 +761,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
795761 self .reqs_in_flight [request .request_id ] = chosen_engine
796762 chosen_engine .num_reqs_in_flight += 1
797763 if self .num_engines_running >= len (self .core_engines ):
798- await self . _send_input_message (msg , chosen_engine )
764+ await chosen_engine . send_multipart (msg )
799765 else :
800766 # Send request to chosen engine and dp start loop
801767 # control message to all other engines.
802768 self .num_engines_running += len (self .core_engines )
803769 await asyncio .gather (* [
804- self . _send_input_message (
805- msg if engine is chosen_engine else self .start_dp_msg ,
806- engine ) for engine in self .core_engines
770+ engine . send_multipart ( msg if engine is
771+ chosen_engine else self .start_dp_msg )
772+ for engine in self .core_engines
807773 ])
808774
809775 self ._ensure_output_queue_task ()
@@ -828,7 +794,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
828794 # sure to start the other engines:
829795 self .num_engines_running = len (self .core_engines )
830796 coros = [
831- self . _send_input_message (self .start_dp_msg , engine )
797+ engine . send_multipart (self .start_dp_msg )
832798 for engine in self .core_engines
833799 if not engine .num_reqs_in_flight
834800 ]
@@ -854,5 +820,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
854820
855821 async def _abort_requests (self , request_ids : list [str ],
856822 engine : CoreEngine ) -> None :
857- await self . _send_input ( EngineCoreRequestType .ABORT , request_ids ,
858- engine )
823+ await engine . send_multipart (( EngineCoreRequestType .ABORT . value ,
824+ self . encoder . encode ( request_ids )) )
0 commit comments