diff --git a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py index 8ad5cbf9bc..ee45999b1d 100644 --- a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py +++ b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py @@ -99,6 +99,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--vae_use_slicing", action="store_true") parser.add_argument("--vae_use_tiling", action="store_true") parser.add_argument("--enable-cpu-offload", action="store_true") + parser.add_argument("--log-stats", "--log_stats", dest="log_stats", action="store_true", default=False) return parser.parse_args() @@ -155,6 +156,7 @@ async def main(): cache_config=cache_config, parallel_config=parallel_config, enforce_eager=args.enforce_eager, + log_stats=args.log_stats, enable_cpu_offload=args.enable_cpu_offload, diffusion_load_format="dummy", custom_pipeline_args={"pipeline_class": "custom_pipeline.CustomPipeline"}, diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 81fe24e218..305a6d54ab 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -320,6 +320,11 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable layerwise (blockwise) offloading on DiT modules.", ) + parser.add_argument( + "--log-stats", + action="store_true", + help="Enable vLLM-Omni statistics logging.", + ) return parser.parse_args() @@ -383,6 +388,7 @@ def main(): parallel_config=parallel_config, enforce_eager=args.enforce_eager, enable_cpu_offload=args.enable_cpu_offload, + log_stats=args.log_stats, ) print("Pipeline loaded") diff 
--git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index a491ddfc33..38500990d0 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -133,6 +133,11 @@ def parse_args() -> argparse.Namespace: "Default 1 means pure sharding (no replication). " ), ) + parser.add_argument( + "--log-stats", + action="store_true", + help="Enable vLLM-Omni statistics logging.", + ) return parser.parse_args() @@ -187,6 +192,7 @@ def main(): enable_cpu_offload=args.enable_cpu_offload, parallel_config=parallel_config, enforce_eager=args.enforce_eager, + log_stats=args.log_stats, ) if profiler_enabled: diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 3829716068..15cbbbf72a 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -94,6 +94,11 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable cache-dit summary logging after diffusion forward passes.", ) + parser.add_argument( + "--log-stats", + action="store_true", + help="Enable vLLM-Omni statistics logging.", + ) parser.add_argument( "--ulysses-degree", type=int, @@ -299,6 +304,7 @@ def main(): "parallel_config": parallel_config, "enforce_eager": args.enforce_eager, "enable_cpu_offload": args.enable_cpu_offload, + "log_stats": args.log_stats, **lora_args, **quant_kwargs, } diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 40fafa1009..77098998b5 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -109,6 +109,11 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for tensor parallelism (TP) 
inside the DiT.", ) + parser.add_argument( + "--log-stats", + action="store_true", + help="Enable vLLM-Omni statistics logging.", + ) return parser.parse_args() @@ -159,6 +164,7 @@ def main(): enable_cpu_offload=args.enable_cpu_offload, parallel_config=parallel_config, enforce_eager=args.enforce_eager, + log_stats=args.log_stats, ) if profiler_enabled: diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 8e4a9f7a20..bc2b262971 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -67,14 +67,19 @@ def __init__(self, od_config: OmniDiffusionConfig): raise e def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: + diffusion_engine_start_time = time.time() # Apply pre-processing if available + preprocess_time = 0.0 if self.pre_process_func is not None: preprocess_start_time = time.time() request = self.pre_process_func(request) preprocess_time = time.time() - preprocess_start_time logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds") + exec_start_time = time.time() output = self.add_req_and_wait_for_response(request) + exec_total_time = time.time() - exec_start_time + if output.error: raise Exception(f"{output.error}") logger.info("Generation completed successfully.") @@ -103,14 +108,20 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: outputs = [outputs] if outputs is not None else [] metrics = { + "preprocess_time_ms": round(preprocess_time * 1000, 2), + "diffusion_engine_exec_time_ms": round((time.time() - diffusion_engine_start_time) * 1000, 2), + "executor_time_ms": round(exec_total_time * 1000, 2), "image_num": int(request.sampling_params.num_outputs_per_prompt), "resolution": int(request.sampling_params.resolution), - "postprocess_time_ms": postprocess_time * 1000, } + if self.pre_process_func is not None: - metrics["preprocessing_time_ms"] = preprocess_time * 1000 + metrics["preprocessing_time_ms"] = 
round(preprocess_time * 1000, 2) # Handle single request or multiple requests + metrics["postprocess_time_ms"] = round(postprocess_time * 1000, 2) + metrics["num_inference_steps"] = int(request.sampling_params.num_inference_steps) + if len(request.prompts) == 1: # Single request: return single OmniRequestOutput prompt = request.prompts[0] @@ -177,7 +188,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: ) ) return results @staticmethod def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 58dd88eb09..d1bf162640 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -1131,6 +1131,15 @@ def _run_generation( final_output_type=stage.final_output_type, # type: ignore[attr-defined] request_output=engine_outputs, ) + try: + if stage.final_output_type == "text" or metrics.log_stats: + output_to_yield.metrics = metrics.build_output_metrics(stage_id, req_id) + except Exception as e: + # Make metrics contract explicit on failure. + output_to_yield.metrics = {} + logger.exception( + f"[{self._name}] Failed to attach output metrics for req {req_id} at stage {stage_id}: {e}", + ) # Record audio generated frames (only when finished) try: diff --git a/vllm_omni/metrics/stats.py b/vllm_omni/metrics/stats.py index 46fec6fcc1..068b4a30b8 100644 --- a/vllm_omni/metrics/stats.py +++ b/vllm_omni/metrics/stats.py @@ -39,7 +39,7 @@ class StageRequestStats: final_output_type: str | None = None request_id: str | None = None postprocess_time_ms: float = 0.0 - diffusion_metrics: dict[str, int] = None + diffusion_metrics: dict[str, float] | None = None audio_generated_frames: int = 0 @property @@ -293,23 +293,15 @@ def process_stage_metrics( return # 4. 
Finished with output: assign text metrics if available - output_to_yield.metrics = {} - stage_event = next( - (evt for evt in reversed(self.stage_events.get(req_id, [])) if evt.stage_id == stage_id), - None, - ) - if stage_event is not None and stage_event.final_output_type == "text": - output_to_yield.metrics = { - "num_tokens_in": stage_event.num_tokens_in, - "num_tokens_out": stage_event.num_tokens_out, - "stage_id": stage_event.stage_id, - "final_output_type": stage_event.final_output_type, - } + output_to_yield.metrics = self.build_output_metrics(stage_id, req_id) # 5. Finished: record audio generated frames self.record_audio_generated_frames(output_to_yield, stage_id, req_id) except Exception: + if output_to_yield is not None: + # Make metrics contract explicit on failure. + output_to_yield.metrics = {} logger.exception( "Failed to process metrics for stage %s, req %s", stage_id, @@ -329,12 +321,56 @@ def _as_stage_request_stats( stats.request_id = req_id stats.final_output_type = final_output_type stats.diffusion_metrics = ( - {k: int(v) for k, v in self.diffusion_metrics.pop(req_id, {}).items()} + {k: float(v) for k, v in self.diffusion_metrics.pop(req_id, {}).items()} if req_id in self.diffusion_metrics else None ) return stats + def _get_stage_event(self, stage_id: int, req_id: Any) -> StageRequestStats | None: + rid_key = str(req_id) + for evt in reversed(self.stage_events.get(rid_key, [])): + if evt.stage_id == stage_id: + return evt + return None + + def _collect_diffusion_metrics(self, req_id: Any) -> dict[str, float]: + """Aggregate diffusion metrics across all stages for a request.""" + rid_key = str(req_id) + merged: dict[str, float] = {} + for evt in self.stage_events.get(rid_key, []): + if not evt.diffusion_metrics: + continue + for key, value in evt.diffusion_metrics.items(): + merged[key] = merged.get(key, 0.0) + float(value) + return merged + + def build_output_metrics(self, stage_id: int, req_id: Any) -> dict[str, Any]: + stage_event = 
self._get_stage_event(stage_id, req_id) + if stage_event is None: + return {} + + merged: dict[str, Any] = {} + + if stage_event.final_output_type == "text": + merged.update( + { + "num_tokens_in": stage_event.num_tokens_in, + "num_tokens_out": stage_event.num_tokens_out, + } + ) + + if self.log_stats: + diffusion_metrics = self._collect_diffusion_metrics(req_id) + if diffusion_metrics: + merged.update(diffusion_metrics) + + if merged: + merged["stage_id"] = stage_event.stage_id + merged["final_output_type"] = stage_event.final_output_type + + return merged + def on_stage_metrics( self, stage_id: int, diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index 5bbd27a1db..904671885c 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -237,19 +237,61 @@ def __repr__(self) -> str: """Custom repr to properly show image count instead of image objects.""" # For images, show count instead of full list images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]" - # Build repr string + + def _repr_nested(obj) -> str: + if isinstance(obj, list): + return "[" + ", ".join(_repr_nested(x) for x in obj) + "]" + if isinstance(obj, OmniRequestOutput): + return obj._repr_multiline(indent=" ") + return repr(obj) + parts = [ f"request_id={self.request_id!r}", f"finished={self.finished}", f"stage_id={self.stage_id}", f"final_output_type={self.final_output_type!r}", - f"request_output={self.request_output}", + f"request_output={_repr_nested(self.request_output)}", f"images={images_repr}", f"prompt={self.prompt!r}", f"latents={self.latents}", f"metrics={self.metrics}", f"multimodal_output={self._multimodal_output}", ] return f"OmniRequestOutput({', '.join(parts)})" + + def _repr_multiline(self, indent: str = "") -> str: + """Helper to produce multi-line, indented repr for nested logging.""" + images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]" + + def _repr_nested(obj, ind: str) -> str: + if isinstance(obj, list): + inner = 
",\n".join(_repr_nested(x, ind + " ") for x in obj) + return "[\n" + inner + "\n" + ind + "]" + if isinstance(obj, OmniRequestOutput): + return obj._repr_multiline(indent=ind + " ") + return repr(obj) + + # Format metrics with each key-value pair on a separate line + if self.metrics: + metrics_indent = indent + " " + metrics_lines = f",\n{metrics_indent}".join(f"{k!r}: {v!r}" for k, v in self.metrics.items()) + metrics_repr = f"{{\n{metrics_indent}{metrics_lines}\n{indent} }}" + else: + metrics_repr = "{}" + + lines = [ + f"{indent}OmniRequestOutput(", + f"{indent} request_id={self.request_id!r},", + f"{indent} finished={self.finished},", + f"{indent} stage_id={self.stage_id},", + f"{indent} final_output_type={self.final_output_type!r},", + f"{indent} request_output={_repr_nested(self.request_output, indent + ' ')},", + f"{indent} images={images_repr},", + f"{indent} prompt={self.prompt!r},", + f"{indent} latents={self.latents},", + f"{indent} metrics={metrics_repr},", + f"{indent})", + ] + return "\n".join(lines)