Skip to content

Commit 9e39c1f

Browse files
committed
optimize code
Signed-off-by: wuhang <wuhang6@huawei.com>
1 parent 0e1105f commit 9e39c1f

File tree

6 files changed

+341
-533
lines changed

6 files changed

+341
-533
lines changed

vllm_omni/entrypoints/async_omni.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,13 @@ def _weak_close_cleanup_async(stage_list, stage_in_queues, stage_out_queues, ray
5050
q.put_nowait(SHUTDOWN_TASK)
5151
except Exception as e:
5252
logger.warning(f"Failed to send shutdown signal to stage input queue: {e}")
53-
try:
54-
close_fn = getattr(q, "close", None)
55-
if callable(close_fn):
56-
close_fn()
57-
except Exception:
58-
pass
53+
close_fn = getattr(q, "close", None)
54+
if callable(close_fn):
55+
close_fn()
5956
for q in stage_out_queues:
60-
try:
61-
close_fn = getattr(q, "close", None)
62-
if callable(close_fn):
63-
close_fn()
64-
except Exception:
65-
pass
57+
close_fn = getattr(q, "close", None)
58+
if callable(close_fn):
59+
close_fn()
6660
for stage in stage_list:
6761
try:
6862
stage.stop_stage_worker()
@@ -73,10 +67,7 @@ def _weak_close_cleanup_async(stage_list, stage_in_queues, stage_out_queues, ray
7367
if output_handler is not None:
7468
output_handler.cancel()
7569
if zmq_ctx is not None:
76-
try:
77-
zmq_ctx.term()
78-
except Exception:
79-
pass
70+
zmq_ctx.term()
8071

8172

8273
class AsyncOmni(OmniBase):
@@ -293,8 +284,7 @@ async def generate(
293284
async with self._pause_cond:
294285
await self._pause_cond.wait_for(lambda: not self._paused)
295286

296-
logger.info(f"[{self._name}] generate() called")
297-
logger.info(f"====== {self.stage_list=}")
287+
logger.debug(f"[{self._name}] generate() called")
298288
try:
299289
# Start output handler on the first call to generate()
300290
self._run_output_handler()

vllm_omni/entrypoints/cli/serve.py

Lines changed: 49 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
import signal
1111
from typing import Any
1212

13+
import msgspec.msgpack
1314
import uvloop
1415
import zmq
15-
from omegaconf import OmegaConf
1616
from vllm.entrypoints.cli.types import CLISubcommand
1717
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
1818
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
1919
from vllm.logger import init_logger
2020
from vllm.utils.argparse_utils import FlexibleArgumentParser
21+
from vllm.utils.network_utils import make_zmq_socket
22+
from vllm.v1.utils import get_engine_client_zmq_addr
2123

2224
from vllm_omni.distributed.omni_connectors import (
2325
get_connectors_config_for_stage,
@@ -27,13 +29,14 @@
2729
from vllm_omni.entrypoints.omni_stage import OmniStage
2830
from vllm_omni.entrypoints.openai.api_server import omni_run_server
2931
from vllm_omni.entrypoints.utils import (
30-
load_stage_configs_from_model,
31-
load_stage_configs_from_yaml,
32-
resolve_model_config_path,
32+
build_base_engine_args,
33+
load_and_resolve_stage_configs,
3334
)
3435

3536
logger = init_logger(__name__)
3637

38+
HANDSHAKE_TIMEOUT_MINS = 5
39+
3740
DESCRIPTION = """Launch a local OpenAI-compatible API server to serve Omni models
3841
via HTTP. Supports both multi-stage LLM models and diffusion models.
3942
@@ -296,47 +299,21 @@ def _create_default_diffusion_stage_cfg(args: argparse.Namespace) -> list[dict[s
296299
def run_headless(args: argparse.Namespace) -> None:
297300
if args.api_server_count > 1:
298301
raise ValueError("api_server_count can't be set in headless mode")
299-
if getattr(args, "worker_backend", "multi_process") != "multi_process":
302+
if args.worker_backend != "multi_process":
300303
raise ValueError("headless mode requires worker_backend=multi_process")
301304

302-
model = getattr(args, "model", None)
303-
if not model:
304-
raise ValueError("model must be specified in headless mode")
305-
model = omni_snapshot_download(model)
306-
307-
tokenizer = getattr(args, "tokenizer", None)
308-
base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None
309-
310-
parallel_keys = [
311-
"tensor_parallel_size",
312-
"pipeline_parallel_size",
313-
"data_parallel_size",
314-
"data_parallel_size_local",
315-
"data_parallel_backend",
316-
"distributed_executor_backend",
317-
]
318-
parallel_overrides = {
319-
k: getattr(args, k) for k in parallel_keys if hasattr(args, k) and getattr(args, k) is not None
320-
}
321-
if parallel_overrides:
322-
base_engine_args = base_engine_args or {}
323-
base_engine_args.update(parallel_overrides)
324-
325-
stage_configs_path = getattr(args, "stage_configs_path", None)
326-
if stage_configs_path is None:
327-
config_path = resolve_model_config_path(model)
328-
stage_configs = load_stage_configs_from_model(model, base_engine_args=base_engine_args)
329-
if not stage_configs:
330-
default_stage_cfg = _create_default_diffusion_stage_cfg(args)
331-
stage_configs = OmegaConf.create(default_stage_cfg)
332-
else:
333-
config_path = stage_configs_path
334-
stage_configs = load_stage_configs_from_yaml(stage_configs_path, base_engine_args=base_engine_args)
335-
336-
if not stage_configs:
337-
raise ValueError("No stage configs found; provide --stage-configs-path or a supported model.")
338-
339-
single_stage_id = getattr(args, "stage_id", None)
305+
model = omni_snapshot_download(args.model)
306+
307+
base_engine_args = build_base_engine_args(args)
308+
stage_configs_path = args.stage_configs_path
309+
config_path, stage_configs = load_and_resolve_stage_configs(
310+
model,
311+
stage_configs_path,
312+
base_engine_args,
313+
default_stage_cfg_factory=lambda: _create_default_diffusion_stage_cfg(args),
314+
)
315+
316+
single_stage_id = args.stage_id
340317
if single_stage_id is None:
341318
if len(stage_configs) != 1:
342319
raise ValueError("--stage-id is required in headless mode for multi-stage configs")
@@ -350,47 +327,42 @@ def run_headless(args: argparse.Namespace) -> None:
350327
if stage_config is None:
351328
raise ValueError(f"No stage matches stage_id={single_stage_id}.")
352329

330+
# TODO(wuhang): Support connectors config by cli
353331
transfer_config = load_omni_transfer_config(config_path, default_shm_threshold=args.shm_threshold_bytes)
354332
connectors_config = get_connectors_config_for_stage(transfer_config, single_stage_id)
355333

356-
omni_master_address = getattr(args, "omni_master_address", None) or "127.0.0.1"
357-
omni_master_port = int(getattr(args, "omni_master_port", 5555) or 5555)
334+
omni_master_address = args.omni_master_address
335+
omni_master_port = args.omni_master_port
358336

359337
# Perform handshake with orchestrator to get dynamically allocated endpoints
360-
zmq_ctx = zmq.Context()
361-
handshake_socket = zmq_ctx.socket(zmq.REQ)
362-
handshake_socket.linger = 0
363-
handshake_endpoint = f"tcp://{omni_master_address}:{omni_master_port}"
364-
365-
try:
366-
handshake_socket.connect(handshake_endpoint)
367-
handshake_msg = {"type": "handshake", "stage_id": single_stage_id}
368-
handshake_socket.send_pyobj(handshake_msg)
369-
370-
# Wait for response with timeout
371-
if handshake_socket.poll(timeout=10000): # 10 second timeout
372-
response = handshake_socket.recv_pyobj()
373-
if not response.get("ok", False):
374-
error_msg = response.get("error", "unknown error")
338+
with zmq.Context() as zmq_ctx:
339+
handshake_endpoint = get_engine_client_zmq_addr(
340+
local_only=False, host=omni_master_address, port=omni_master_port
341+
)
342+
343+
with make_zmq_socket(zmq_ctx, handshake_endpoint, zmq.REQ, bind=False, linger=5000) as handshake_socket:
344+
# TODO(wuhang): Define protocol in python dataclass.
345+
handshake_msg = {"type": "handshake", "stage_id": single_stage_id}
346+
handshake_socket.send(msgspec.msgpack.encode(handshake_msg))
347+
348+
# Wait for response with timeout
349+
if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
350+
raise RuntimeError(
351+
f"Handshake timeout ({HANDSHAKE_TIMEOUT_MINS} minutes) for stage-{single_stage_id} "
352+
f"at {handshake_endpoint}"
353+
)
354+
355+
response = msgspec.msgpack.decode(handshake_socket.recv())
356+
if not response["ok"]:
357+
error_msg = response["error"]
375358
raise RuntimeError(f"Handshake failed for stage-{single_stage_id}: {error_msg}")
376359

377-
in_q_spec = response.get("in_spec")
378-
out_q_spec = response.get("out_spec")
379-
380-
if in_q_spec is None or out_q_spec is None:
381-
raise RuntimeError(f"Handshake response missing specs for stage-{single_stage_id}")
360+
in_endpoint, out_endpoint = response["in_endpoint"], response["out_endpoint"]
382361

383362
logger.info(
384363
f"[Headless] Stage-{single_stage_id} received endpoints via handshake: "
385-
f"in={in_q_spec.endpoint}, out={out_q_spec.endpoint}"
364+
f"in={in_endpoint}, out={out_endpoint}"
386365
)
387-
else:
388-
raise TimeoutError(f"Handshake timeout for stage-{single_stage_id} at {handshake_endpoint}")
389-
390-
finally:
391-
handshake_socket.close(0)
392-
in_q = None
393-
out_q = None
394366

395367
shutdown_requested = False
396368

@@ -404,18 +376,17 @@ def signal_handler(signum, frame):
404376
signal.signal(signal.SIGTERM, signal_handler)
405377
signal.signal(signal.SIGINT, signal_handler)
406378

407-
stage = OmniStage(stage_config, stage_init_timeout=int(getattr(args, "stage_init_timeout", 300)))
408-
stage.set_zmq_master(omni_master_address, omni_master_port)
409-
stage.attach_queues(in_q, out_q, in_q_spec=in_q_spec, out_q_spec=out_q_spec)
379+
stage = OmniStage(stage_config, stage_init_timeout=args.stage_init_timeout)
380+
stage.attach_queues(in_endpoint, out_endpoint)
410381

411382
old_env = os.environ.get("VLLM_LOGGING_PREFIX")
412383
os.environ["VLLM_LOGGING_PREFIX"] = f"[Stage-{single_stage_id}] {'' if old_env is None else old_env}"
413384
try:
414385
stage.init_stage_worker(
415386
model,
416387
is_async=True,
417-
shm_threshold_bytes=int(getattr(args, "shm_threshold_bytes", 65536)),
418-
batch_timeout=int(getattr(args, "batch_timeout", 10)),
388+
shm_threshold_bytes=int(args.shm_threshold_bytes),
389+
batch_timeout=int(args.batch_timeout),
419390
connectors_config=connectors_config,
420391
worker_backend="multi_process",
421392
ignore_runtime_config=True,
@@ -424,14 +395,6 @@ def signal_handler(signum, frame):
424395
stage._proc.join()
425396
finally:
426397
stage.stop_stage_worker()
427-
try:
428-
zmq_ctx.term()
429-
except Exception:
430-
pass
431-
if old_env is None:
432-
os.environ.pop("VLLM_LOGGING_PREFIX", None)
433-
else:
434-
os.environ["VLLM_LOGGING_PREFIX"] = old_env
435398

436399

437400
def cmd_init() -> list[CLISubcommand]:

0 commit comments

Comments (0)