33import asyncio
44import copy
55import time
6- import weakref
76from collections .abc import AsyncGenerator , Iterable , Sequence
87from typing import Any
98
1817from vllm_omni .config import OmniModelConfig
1918from vllm_omni .diffusion .data import DiffusionParallelConfig
2019from vllm_omni .distributed .omni_connectors .adapter import compute_talker_prompt_ids_length , try_send_via_connector
21- from vllm_omni .distributed .ray_utils .utils import try_close_ray
2220from vllm_omni .engine .input_processor import OmniInputProcessor
2321from vllm_omni .entrypoints .client_request_state import ClientRequestState
2422from vllm_omni .entrypoints .omni import OmniBase
2523from vllm_omni .entrypoints .omni_stage import OmniStage
26- from vllm_omni .entrypoints .stage_utils import SHUTDOWN_TASK , OmniStageTaskType
24+ from vllm_omni .entrypoints .stage_utils import OmniStageTaskType
2725from vllm_omni .entrypoints .stage_utils import maybe_load_from_ipc as _load
2826from vllm_omni .entrypoints .utils import (
2927 get_final_stage_id_for_e2e ,
3836logger = init_logger (__name__ )
3937
4038
def _weak_close_cleanup_async(stage_list, stage_in_queues, stage_out_queues, ray_pg, output_handler, zmq_ctx=None):
    """Weak-reference cleanup function for AsyncOmni instances.

    Intended to run from a ``weakref.finalize`` callback at garbage
    collection time, so every step is best-effort and must not raise.

    Args:
        stage_list: Stages whose worker processes should be stopped.
        stage_in_queues: Per-stage input queues; each receives SHUTDOWN_TASK
            and is then closed if it exposes a ``close()`` method.
        stage_out_queues: Per-stage output queues to close (if closable).
        ray_pg: Ray placement-group handle passed to ``try_close_ray``.
        output_handler: Optional asyncio task to cancel.
        zmq_ctx: Optional ZMQ context to terminate.
    """
    if stage_list:
        for q in stage_in_queues:
            try:
                q.put_nowait(SHUTDOWN_TASK)
            except Exception as e:
                # Best-effort: a dead or full queue must not abort the
                # remaining cleanup steps. Lazy %-formatting avoids building
                # the message unless the warning is actually emitted.
                logger.warning("Failed to send shutdown signal to stage input queue: %s", e)
            close_fn = getattr(q, "close", None)
            if callable(close_fn):
                close_fn()
        for q in stage_out_queues:
            close_fn = getattr(q, "close", None)
            if callable(close_fn):
                close_fn()
        for stage in stage_list:
            try:
                stage.stop_stage_worker()
            except Exception as e:
                logger.warning("Failed to stop stage worker: %s", e)
        try_close_ray(ray_pg)
    # Cancel output handler
    if output_handler is not None:
        output_handler.cancel()
    if zmq_ctx is not None:
        zmq_ctx.term()
68-
6939class AsyncOmni (OmniBase ):
7040 """Asynchronous unified entry point supporting multi-stage pipelines for LLM and Diffusion models.
7141
@@ -107,22 +77,9 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None:
10777
10878 # Request state tracking
10979 self .request_states : dict [str , ClientRequestState ] = {}
110- self .output_handler : asyncio .Task | None = None
11180
11281 super ().__init__ (model , ** kwargs )
11382
114- # Register weak reference cleanup (called on garbage collection)
115- self ._weak_finalizer = weakref .finalize (
116- self ,
117- _weak_close_cleanup_async ,
118- self .stage_list ,
119- self ._stage_in_queues ,
120- self ._stage_out_queues ,
121- self ._ray_pg ,
122- self .output_handler ,
123- self ._zmq_ctx ,
124- )
125-
12683 def _create_default_diffusion_stage_cfg (self , kwargs : dict [str , Any ]) -> dict [str , Any ]:
12784 """Create default diffusion stage configuration."""
12885 # TODO: here is different from the Omni class. We should merge the two in the future.
@@ -216,7 +173,7 @@ def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str
216173 def _wait_for_stages_ready (self , timeout : int = 120 ) -> None :
217174 """Wait for all stages to report readiness."""
218175 super ()._wait_for_stages_ready (timeout )
219- for stage in self .stage_list :
176+ for stage in self .resources . stage_list :
220177 if stage .vllm_config is not None and stage .tokenizer is not None :
221178 try :
222179 vllm_config = stage .vllm_config
@@ -305,11 +262,13 @@ async def generate(
305262 if sampling_params_list is None :
306263 sampling_params_list = self .default_sampling_params_list
307264
308- if len (sampling_params_list ) != len (self .stage_list ):
309- raise ValueError (f"Expected { len (self .stage_list )} sampling params, got { len (sampling_params_list )} " )
265+ if len (sampling_params_list ) != len (self .resources .stage_list ):
266+ raise ValueError (
267+ f"Expected { len (self .resources .stage_list )} sampling params, got { len (sampling_params_list )} "
268+ )
310269
311270 # Orchestrator keeps stage objects for input derivation
312- num_stages = len (self .stage_list )
271+ num_stages = len (self .resources . stage_list )
313272 # Track per-request start time for end-to-end timing
314273 _req_start_ts : dict [int , float ] = {}
315274 _wall_start_ts : float = time .time ()
@@ -318,7 +277,7 @@ async def generate(
318277 # Determine the final stage for E2E stats (highest stage_id with
319278 # final_output=True; fallback to last stage)
320279 final_stage_id_for_e2e = get_final_stage_id_for_e2e (
321- output_modalities , self .output_modalities , self .stage_list
280+ output_modalities , self .output_modalities , self .resources . stage_list
322281 )
323282
324283 # Metrics/aggregation helper
@@ -337,7 +296,7 @@ async def generate(
337296 "engine_inputs" : prompt ,
338297 "sampling_params" : sp0 ,
339298 }
340- self .stage_list [0 ].submit (task )
299+ self .resources . stage_list [0 ].submit (task )
341300 metrics .stage_first_ts [0 ] = metrics .stage_first_ts [0 ] or time .time ()
342301 _req_start_ts [request_id ] = time .time ()
343302 logger .info (
@@ -399,7 +358,7 @@ async def _process_async_results(
399358 all_stages_finished = {stage_id : False for stage_id in range (final_stage_id_for_e2e + 1 )}
400359 submit_flag = True
401360 while not all (all_stages_finished .values ()):
402- for stage_id , stage in enumerate (self .stage_list [: final_stage_id_for_e2e + 1 ]):
361+ for stage_id , stage in enumerate (self .resources . stage_list [: final_stage_id_for_e2e + 1 ]):
403362 if all_stages_finished [stage_id ]:
404363 continue
405364 try :
@@ -420,13 +379,13 @@ async def _process_async_results(
420379 next_prompt_len = max (1 , compute_talker_prompt_ids_length (prompt_token_ids ))
421380 engine_input ["prompt_token_ids" ] = [0 ] * next_prompt_len
422381 engine_input ["multi_modal_data" ] = engine_input ["mm_processor_kwargs" ] = None
423- for i in range (1 , len (self .stage_list )):
382+ for i in range (1 , len (self .resources . stage_list )):
424383 task = {
425384 "request_id" : request_id ,
426385 "engine_inputs" : engine_input ,
427386 "sampling_params" : sampling_params_list [i ],
428387 }
429- self .stage_list [i ].submit (task )
388+ self .resources . stage_list [i ].submit (task )
430389 metrics .stage_first_ts [i ] = time .time ()
431390 all_stages_finished [stage_id ] = finished
432391
@@ -461,10 +420,10 @@ async def _process_sequential_results(
461420 # Forward to next stage if there is one
462421 next_stage_id = stage_id + 1
463422 if next_stage_id <= final_stage_id_for_e2e :
464- next_stage : OmniStage = self .stage_list [next_stage_id ]
423+ next_stage : OmniStage = self .resources . stage_list [next_stage_id ]
465424 # Derive inputs for the next stage, record postprocess time
466425 with metrics .stage_postprocess_timer (stage_id , request_id ):
467- next_inputs = next_stage .process_engine_inputs (self .stage_list , prompt )
426+ next_inputs = next_stage .process_engine_inputs (self .resources . stage_list , prompt )
468427 sp_next : SamplingParams = sampling_params_list [next_stage_id ]
469428
470429 # Check if we have a connector for this edge
@@ -481,7 +440,7 @@ async def _process_sequential_results(
481440 next_inputs = next_inputs ,
482441 sampling_params = sp_next ,
483442 original_prompt = prompt ,
484- next_stage_queue_submit_fn = self .stage_list [next_stage_id ].submit ,
443+ next_stage_queue_submit_fn = self .resources . stage_list [next_stage_id ].submit ,
485444 metrics = metrics ,
486445 )
487446
@@ -574,10 +533,10 @@ def _process_single_result(
574533 return engine_outputs , finished , output_to_yield
575534
576535 def _run_output_handler (self ) -> None :
577- if self .output_handler is not None :
536+ if self .resources . output_handler is not None :
578537 return
579538
580- stage_list = self .stage_list
539+ stage_list = self .resources . stage_list
581540 request_states = self .request_states
582541
583542 async def output_handler ():
@@ -623,14 +582,14 @@ async def output_handler():
623582 else :
624583 await req_state .queue .put (error_msg )
625584 error_msg = {"request_id" : req_state .request_id , "error" : str (e )}
626- self .output_handler = None # Make possible for restart
585+ self .resources . output_handler = None # Make possible for restart
627586
628- self .output_handler = asyncio .create_task (output_handler ())
587+ self .resources . output_handler = asyncio .create_task (output_handler ())
629588
630589 @property
631590 def is_running (self ) -> bool :
632591 # Is None before the loop is started.
633- return len (self ._stage_in_queues ) > 0
592+ return len (self .resources . _stage_in_queues ) > 0
634593
635594 @property
636595 def is_stopped (self ) -> bool :
@@ -654,20 +613,20 @@ def dead_error(self) -> BaseException:
654613
655614 async def abort (self , request_id : str | Iterable [str ]) -> None :
656615 abort_task = {"type" : OmniStageTaskType .ABORT , "request_id" : request_id }
657- for stage in self .stage_list :
616+ for stage in self .resources . stage_list :
658617 stage .submit (abort_task )
659618 return None
660619
661620 async def get_vllm_config (self ) -> VllmConfig :
662- for stage in self .stage_list :
621+ for stage in self .resources . stage_list :
663622 if stage .is_comprehension :
664623 # Use the vllm_config received from worker process
665624 if stage .vllm_config is not None :
666625 return stage .vllm_config
667626 return None
668627
669628 async def get_model_config (self ) -> OmniModelConfig :
670- for stage in self .stage_list :
629+ for stage in self .resources . stage_list :
671630 if stage .is_comprehension :
672631 # Use the vllm_config received from worker process
673632 if stage .vllm_config is not None :
@@ -678,13 +637,13 @@ async def get_input_preprocessor(self) -> InputPreprocessor:
678637 return None
679638
680639 async def get_tokenizer (self ) -> TokenizerLike :
681- for stage in self .stage_list :
640+ for stage in self .resources . stage_list :
682641 if stage .is_comprehension :
683642 return stage .tokenizer
684643 return None
685644
686645 async def is_tracing_enabled (self ) -> bool :
687- for stage in self .stage_list :
646+ for stage in self .resources . stage_list :
688647 if stage .is_comprehension :
689648 return stage .is_tracing_enabled
690649 return False
0 commit comments