Skip to content

Commit a58d38e

Browse files
committed
Integrate BackgroundResources into Omni and AsyncOmni
Signed-off-by: Daniel Huang <daniel1.huang@intel.com>
1 parent 782eba2 commit a58d38e

File tree

2 files changed

+91
-184
lines changed

2 files changed

+91
-184
lines changed

vllm_omni/entrypoints/async_omni.py

Lines changed: 25 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import copy
55
import time
6-
import weakref
76
from collections.abc import AsyncGenerator, Iterable, Sequence
87
from typing import Any
98

@@ -18,12 +17,11 @@
1817
from vllm_omni.config import OmniModelConfig
1918
from vllm_omni.diffusion.data import DiffusionParallelConfig
2019
from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length, try_send_via_connector
21-
from vllm_omni.distributed.ray_utils.utils import try_close_ray
2220
from vllm_omni.engine.input_processor import OmniInputProcessor
2321
from vllm_omni.entrypoints.client_request_state import ClientRequestState
2422
from vllm_omni.entrypoints.omni import OmniBase
2523
from vllm_omni.entrypoints.omni_stage import OmniStage
26-
from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType
24+
from vllm_omni.entrypoints.stage_utils import OmniStageTaskType
2725
from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load
2826
from vllm_omni.entrypoints.utils import (
2927
get_final_stage_id_for_e2e,
@@ -38,34 +36,6 @@
3836
logger = init_logger(__name__)
3937

4038

41-
def _weak_close_cleanup_async(stage_list, stage_in_queues, stage_out_queues, ray_pg, output_handler, zmq_ctx=None):
42-
"""Weak reference cleanup function for AsyncOmni instances."""
43-
if stage_list:
44-
for q in stage_in_queues:
45-
try:
46-
q.put_nowait(SHUTDOWN_TASK)
47-
except Exception as e:
48-
logger.warning(f"Failed to send shutdown signal to stage input queue: {e}")
49-
close_fn = getattr(q, "close", None)
50-
if callable(close_fn):
51-
close_fn()
52-
for q in stage_out_queues:
53-
close_fn = getattr(q, "close", None)
54-
if callable(close_fn):
55-
close_fn()
56-
for stage in stage_list:
57-
try:
58-
stage.stop_stage_worker()
59-
except Exception as e:
60-
logger.warning(f"Failed to stop stage worker: {e}")
61-
try_close_ray(ray_pg)
62-
# Cancel output handler
63-
if output_handler is not None:
64-
output_handler.cancel()
65-
if zmq_ctx is not None:
66-
zmq_ctx.term()
67-
68-
6939
class AsyncOmni(OmniBase):
7040
"""Asynchronous unified entry point supporting multi-stage pipelines for LLM and Diffusion models.
7141
@@ -107,22 +77,9 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None:
10777

10878
# Request state tracking
10979
self.request_states: dict[str, ClientRequestState] = {}
110-
self.output_handler: asyncio.Task | None = None
11180

11281
super().__init__(model, **kwargs)
11382

114-
# Register weak reference cleanup (called on garbage collection)
115-
self._weak_finalizer = weakref.finalize(
116-
self,
117-
_weak_close_cleanup_async,
118-
self.stage_list,
119-
self._stage_in_queues,
120-
self._stage_out_queues,
121-
self._ray_pg,
122-
self.output_handler,
123-
self._zmq_ctx,
124-
)
125-
12683
def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[str, Any]:
12784
"""Create default diffusion stage configuration."""
12885
# TODO: here is different from the Omni class. We should merge the two in the future.
@@ -216,7 +173,7 @@ def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str
216173
def _wait_for_stages_ready(self, timeout: int = 120) -> None:
217174
"""Wait for all stages to report readiness."""
218175
super()._wait_for_stages_ready(timeout)
219-
for stage in self.stage_list:
176+
for stage in self.resources.stage_list:
220177
if stage.vllm_config is not None and stage.tokenizer is not None:
221178
try:
222179
vllm_config = stage.vllm_config
@@ -305,11 +262,13 @@ async def generate(
305262
if sampling_params_list is None:
306263
sampling_params_list = self.default_sampling_params_list
307264

308-
if len(sampling_params_list) != len(self.stage_list):
309-
raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}")
265+
if len(sampling_params_list) != len(self.resources.stage_list):
266+
raise ValueError(
267+
f"Expected {len(self.resources.stage_list)} sampling params, got {len(sampling_params_list)}"
268+
)
310269

311270
# Orchestrator keeps stage objects for input derivation
312-
num_stages = len(self.stage_list)
271+
num_stages = len(self.resources.stage_list)
313272
# Track per-request start time for end-to-end timing
314273
_req_start_ts: dict[int, float] = {}
315274
_wall_start_ts: float = time.time()
@@ -318,7 +277,7 @@ async def generate(
318277
# Determine the final stage for E2E stats (highest stage_id with
319278
# final_output=True; fallback to last stage)
320279
final_stage_id_for_e2e = get_final_stage_id_for_e2e(
321-
output_modalities, self.output_modalities, self.stage_list
280+
output_modalities, self.output_modalities, self.resources.stage_list
322281
)
323282

324283
# Metrics/aggregation helper
@@ -337,7 +296,7 @@ async def generate(
337296
"engine_inputs": prompt,
338297
"sampling_params": sp0,
339298
}
340-
self.stage_list[0].submit(task)
299+
self.resources.stage_list[0].submit(task)
341300
metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time()
342301
_req_start_ts[request_id] = time.time()
343302
logger.info(
@@ -399,7 +358,7 @@ async def _process_async_results(
399358
all_stages_finished = {stage_id: False for stage_id in range(final_stage_id_for_e2e + 1)}
400359
submit_flag = True
401360
while not all(all_stages_finished.values()):
402-
for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]):
361+
for stage_id, stage in enumerate(self.resources.stage_list[: final_stage_id_for_e2e + 1]):
403362
if all_stages_finished[stage_id]:
404363
continue
405364
try:
@@ -420,13 +379,13 @@ async def _process_async_results(
420379
next_prompt_len = max(1, compute_talker_prompt_ids_length(prompt_token_ids))
421380
engine_input["prompt_token_ids"] = [0] * next_prompt_len
422381
engine_input["multi_modal_data"] = engine_input["mm_processor_kwargs"] = None
423-
for i in range(1, len(self.stage_list)):
382+
for i in range(1, len(self.resources.stage_list)):
424383
task = {
425384
"request_id": request_id,
426385
"engine_inputs": engine_input,
427386
"sampling_params": sampling_params_list[i],
428387
}
429-
self.stage_list[i].submit(task)
388+
self.resources.stage_list[i].submit(task)
430389
metrics.stage_first_ts[i] = time.time()
431390
all_stages_finished[stage_id] = finished
432391

@@ -461,10 +420,10 @@ async def _process_sequential_results(
461420
# Forward to next stage if there is one
462421
next_stage_id = stage_id + 1
463422
if next_stage_id <= final_stage_id_for_e2e:
464-
next_stage: OmniStage = self.stage_list[next_stage_id]
423+
next_stage: OmniStage = self.resources.stage_list[next_stage_id]
465424
# Derive inputs for the next stage, record postprocess time
466425
with metrics.stage_postprocess_timer(stage_id, request_id):
467-
next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt)
426+
next_inputs = next_stage.process_engine_inputs(self.resources.stage_list, prompt)
468427
sp_next: SamplingParams = sampling_params_list[next_stage_id]
469428

470429
# Check if we have a connector for this edge
@@ -481,7 +440,7 @@ async def _process_sequential_results(
481440
next_inputs=next_inputs,
482441
sampling_params=sp_next,
483442
original_prompt=prompt,
484-
next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit,
443+
next_stage_queue_submit_fn=self.resources.stage_list[next_stage_id].submit,
485444
metrics=metrics,
486445
)
487446

@@ -574,10 +533,10 @@ def _process_single_result(
574533
return engine_outputs, finished, output_to_yield
575534

576535
def _run_output_handler(self) -> None:
577-
if self.output_handler is not None:
536+
if self.resources.output_handler is not None:
578537
return
579538

580-
stage_list = self.stage_list
539+
stage_list = self.resources.stage_list
581540
request_states = self.request_states
582541

583542
async def output_handler():
@@ -623,14 +582,14 @@ async def output_handler():
623582
else:
624583
await req_state.queue.put(error_msg)
625584
error_msg = {"request_id": req_state.request_id, "error": str(e)}
626-
self.output_handler = None # Make possible for restart
585+
self.resources.output_handler = None # Make possible for restart
627586

628-
self.output_handler = asyncio.create_task(output_handler())
587+
self.resources.output_handler = asyncio.create_task(output_handler())
629588

630589
@property
631590
def is_running(self) -> bool:
632591
# Is None before the loop is started.
633-
return len(self._stage_in_queues) > 0
592+
return len(self.resources._stage_in_queues) > 0
634593

635594
@property
636595
def is_stopped(self) -> bool:
@@ -654,20 +613,20 @@ def dead_error(self) -> BaseException:
654613

655614
async def abort(self, request_id: str | Iterable[str]) -> None:
656615
abort_task = {"type": OmniStageTaskType.ABORT, "request_id": request_id}
657-
for stage in self.stage_list:
616+
for stage in self.resources.stage_list:
658617
stage.submit(abort_task)
659618
return None
660619

661620
async def get_vllm_config(self) -> VllmConfig:
662-
for stage in self.stage_list:
621+
for stage in self.resources.stage_list:
663622
if stage.is_comprehension:
664623
# Use the vllm_config received from worker process
665624
if stage.vllm_config is not None:
666625
return stage.vllm_config
667626
return None
668627

669628
async def get_model_config(self) -> OmniModelConfig:
670-
for stage in self.stage_list:
629+
for stage in self.resources.stage_list:
671630
if stage.is_comprehension:
672631
# Use the vllm_config received from worker process
673632
if stage.vllm_config is not None:
@@ -678,13 +637,13 @@ async def get_input_preprocessor(self) -> InputPreprocessor:
678637
return None
679638

680639
async def get_tokenizer(self) -> TokenizerLike:
681-
for stage in self.stage_list:
640+
for stage in self.resources.stage_list:
682641
if stage.is_comprehension:
683642
return stage.tokenizer
684643
return None
685644

686645
async def is_tracing_enabled(self) -> bool:
687-
for stage in self.stage_list:
646+
for stage in self.resources.stage_list:
688647
if stage.is_comprehension:
689648
return stage.is_tracing_enabled
690649
return False

0 commit comments

Comments (0)