1010import signal
1111from typing import Any
1212
13+ import msgspec .msgpack
1314import uvloop
1415import zmq
15- from omegaconf import OmegaConf
1616from vllm .entrypoints .cli .types import CLISubcommand
1717from vllm .entrypoints .openai .cli_args import make_arg_parser , validate_parsed_serve_args
1818from vllm .entrypoints .utils import VLLM_SUBCMD_PARSER_EPILOG
1919from vllm .logger import init_logger
2020from vllm .utils .argparse_utils import FlexibleArgumentParser
21+ from vllm .utils .network_utils import make_zmq_socket
22+ from vllm .v1 .utils import get_engine_client_zmq_addr
2123
2224from vllm_omni .distributed .omni_connectors import (
2325 get_connectors_config_for_stage ,
2729from vllm_omni .entrypoints .omni_stage import OmniStage
2830from vllm_omni .entrypoints .openai .api_server import omni_run_server
2931from vllm_omni .entrypoints .utils import (
30- load_stage_configs_from_model ,
31- load_stage_configs_from_yaml ,
32- resolve_model_config_path ,
32+ build_base_engine_args ,
33+ load_and_resolve_stage_configs ,
3334)
3435
3536logger = init_logger (__name__ )
3637
38+ HANDSHAKE_TIMEOUT_MINS = 5
39+
3740DESCRIPTION = """Launch a local OpenAI-compatible API server to serve Omni models
3841via HTTP. Supports both multi-stage LLM models and diffusion models.
3942
@@ -296,47 +299,21 @@ def _create_default_diffusion_stage_cfg(args: argparse.Namespace) -> list[dict[s
296299def run_headless (args : argparse .Namespace ) -> None :
297300 if args .api_server_count > 1 :
298301 raise ValueError ("api_server_count can't be set in headless mode" )
299- if getattr ( args , " worker_backend" , "multi_process" ) != "multi_process" :
302+ if args . worker_backend != "multi_process" :
300303 raise ValueError ("headless mode requires worker_backend=multi_process" )
301304
302- model = getattr (args , "model" , None )
303- if not model :
304- raise ValueError ("model must be specified in headless mode" )
305- model = omni_snapshot_download (model )
306-
307- tokenizer = getattr (args , "tokenizer" , None )
308- base_engine_args = {"tokenizer" : tokenizer } if tokenizer is not None else None
309-
310- parallel_keys = [
311- "tensor_parallel_size" ,
312- "pipeline_parallel_size" ,
313- "data_parallel_size" ,
314- "data_parallel_size_local" ,
315- "data_parallel_backend" ,
316- "distributed_executor_backend" ,
317- ]
318- parallel_overrides = {
319- k : getattr (args , k ) for k in parallel_keys if hasattr (args , k ) and getattr (args , k ) is not None
320- }
321- if parallel_overrides :
322- base_engine_args = base_engine_args or {}
323- base_engine_args .update (parallel_overrides )
324-
325- stage_configs_path = getattr (args , "stage_configs_path" , None )
326- if stage_configs_path is None :
327- config_path = resolve_model_config_path (model )
328- stage_configs = load_stage_configs_from_model (model , base_engine_args = base_engine_args )
329- if not stage_configs :
330- default_stage_cfg = _create_default_diffusion_stage_cfg (args )
331- stage_configs = OmegaConf .create (default_stage_cfg )
332- else :
333- config_path = stage_configs_path
334- stage_configs = load_stage_configs_from_yaml (stage_configs_path , base_engine_args = base_engine_args )
335-
336- if not stage_configs :
337- raise ValueError ("No stage configs found; provide --stage-configs-path or a supported model." )
338-
339- single_stage_id = getattr (args , "stage_id" , None )
305+ model = omni_snapshot_download (args .model )
306+
307+ base_engine_args = build_base_engine_args (args )
308+ stage_configs_path = args .stage_configs_path
309+ config_path , stage_configs = load_and_resolve_stage_configs (
310+ model ,
311+ stage_configs_path ,
312+ base_engine_args ,
313+ default_stage_cfg_factory = lambda : _create_default_diffusion_stage_cfg (args ),
314+ )
315+
316+ single_stage_id = args .stage_id
340317 if single_stage_id is None :
341318 if len (stage_configs ) != 1 :
342319 raise ValueError ("--stage-id is required in headless mode for multi-stage configs" )
@@ -350,47 +327,42 @@ def run_headless(args: argparse.Namespace) -> None:
350327 if stage_config is None :
351328 raise ValueError (f"No stage matches stage_id={ single_stage_id } ." )
352329
330+ # TODO(wuhang): Support configuring connectors via the CLI.
353331 transfer_config = load_omni_transfer_config (config_path , default_shm_threshold = args .shm_threshold_bytes )
354332 connectors_config = get_connectors_config_for_stage (transfer_config , single_stage_id )
355333
356- omni_master_address = getattr ( args , "omni_master_address" , None ) or "127.0.0.1"
357- omni_master_port = int ( getattr ( args , " omni_master_port" , 5555 ) or 5555 )
334+ omni_master_address = args . omni_master_address
335+ omni_master_port = args . omni_master_port
358336
359337 # Perform handshake with orchestrator to get dynamically allocated endpoints
360- zmq_ctx = zmq .Context ()
361- handshake_socket = zmq_ctx .socket (zmq .REQ )
362- handshake_socket .linger = 0
363- handshake_endpoint = f"tcp://{ omni_master_address } :{ omni_master_port } "
364-
365- try :
366- handshake_socket .connect (handshake_endpoint )
367- handshake_msg = {"type" : "handshake" , "stage_id" : single_stage_id }
368- handshake_socket .send_pyobj (handshake_msg )
369-
370- # Wait for response with timeout
371- if handshake_socket .poll (timeout = 10000 ): # 10 second timeout
372- response = handshake_socket .recv_pyobj ()
373- if not response .get ("ok" , False ):
374- error_msg = response .get ("error" , "unknown error" )
338+ with zmq .Context () as zmq_ctx :
339+ handshake_endpoint = get_engine_client_zmq_addr (
340+ local_only = False , host = omni_master_address , port = omni_master_port
341+ )
342+
343+ with make_zmq_socket (zmq_ctx , handshake_endpoint , zmq .REQ , bind = False , linger = 5000 ) as handshake_socket :
344+ # TODO(wuhang): Define this protocol as a Python dataclass.
345+ handshake_msg = {"type" : "handshake" , "stage_id" : single_stage_id }
346+ handshake_socket .send (msgspec .msgpack .encode (handshake_msg ))
347+
348+ # Wait for response with timeout
349+ if not handshake_socket .poll (timeout = HANDSHAKE_TIMEOUT_MINS * 60_000 ):
350+ raise RuntimeError (
351+ f"Handshake timeout ({ HANDSHAKE_TIMEOUT_MINS } minutes) for stage-{ single_stage_id } "
352+ f"at { handshake_endpoint } "
353+ )
354+
355+ response = msgspec .msgpack .decode (handshake_socket .recv ())
356+ if not response ["ok" ]:
357+ error_msg = response ["error" ]
375358 raise RuntimeError (f"Handshake failed for stage-{ single_stage_id } : { error_msg } " )
376359
377- in_q_spec = response .get ("in_spec" )
378- out_q_spec = response .get ("out_spec" )
379-
380- if in_q_spec is None or out_q_spec is None :
381- raise RuntimeError (f"Handshake response missing specs for stage-{ single_stage_id } " )
360+ in_endpoint , out_endpoint = response ["in_endpoint" ], response ["out_endpoint" ]
382361
383362 logger .info (
384363 f"[Headless] Stage-{ single_stage_id } received endpoints via handshake: "
385- f"in={ in_q_spec . endpoint } , out={ out_q_spec . endpoint } "
364+ f"in={ in_endpoint } , out={ out_endpoint } "
386365 )
387- else :
388- raise TimeoutError (f"Handshake timeout for stage-{ single_stage_id } at { handshake_endpoint } " )
389-
390- finally :
391- handshake_socket .close (0 )
392- in_q = None
393- out_q = None
394366
395367 shutdown_requested = False
396368
@@ -404,18 +376,17 @@ def signal_handler(signum, frame):
404376 signal .signal (signal .SIGTERM , signal_handler )
405377 signal .signal (signal .SIGINT , signal_handler )
406378
407- stage = OmniStage (stage_config , stage_init_timeout = int (getattr (args , "stage_init_timeout" , 300 )))
408- stage .set_zmq_master (omni_master_address , omni_master_port )
409- stage .attach_queues (in_q , out_q , in_q_spec = in_q_spec , out_q_spec = out_q_spec )
379+ stage = OmniStage (stage_config , stage_init_timeout = args .stage_init_timeout )
380+ stage .attach_queues (in_endpoint , out_endpoint )
410381
411382 old_env = os .environ .get ("VLLM_LOGGING_PREFIX" )
412383 os .environ ["VLLM_LOGGING_PREFIX" ] = f"[Stage-{ single_stage_id } ] { '' if old_env is None else old_env } "
413384 try :
414385 stage .init_stage_worker (
415386 model ,
416387 is_async = True ,
417- shm_threshold_bytes = int (getattr ( args , " shm_threshold_bytes" , 65536 ) ),
418- batch_timeout = int (getattr ( args , " batch_timeout" , 10 ) ),
388+ shm_threshold_bytes = int (args . shm_threshold_bytes ),
389+ batch_timeout = int (args . batch_timeout ),
419390 connectors_config = connectors_config ,
420391 worker_backend = "multi_process" ,
421392 ignore_runtime_config = True ,
@@ -424,14 +395,6 @@ def signal_handler(signum, frame):
424395 stage ._proc .join ()
425396 finally :
426397 stage .stop_stage_worker ()
427- try :
428- zmq_ctx .term ()
429- except Exception :
430- pass
431- if old_env is None :
432- os .environ .pop ("VLLM_LOGGING_PREFIX" , None )
433- else :
434- os .environ ["VLLM_LOGGING_PREFIX" ] = old_env
435398
436399
437400def cmd_init () -> list [CLISubcommand ]:
0 commit comments