Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--vae_use_slicing", action="store_true")
parser.add_argument("--vae_use_tiling", action="store_true")
parser.add_argument("--enable-cpu-offload", action="store_true")
parser.add_argument("--log-stats", "--log_stats", dest="log_stats", action="store_true", default=False)

return parser.parse_args()

Expand Down Expand Up @@ -155,6 +156,7 @@ async def main():
cache_config=cache_config,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
log_stats=args.log_stats,
enable_cpu_offload=args.enable_cpu_offload,
diffusion_load_format="dummy",
custom_pipeline_args={"pipeline_class": "custom_pipeline.CustomPipeline"},
Expand Down Expand Up @@ -259,4 +261,4 @@ async def main():
# Entrypoint
# ===========================
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
6 changes: 6 additions & 0 deletions examples/offline_inference/image_to_image/image_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable layerwise (blockwise) offloading on DiT modules.",
)
parser.add_argument(
"--log-stats",
action="store_true",
help="Enable vLLM-Omni statistics logging.",
)
return parser.parse_args()


Expand Down Expand Up @@ -383,6 +388,7 @@ def main():
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
enable_cpu_offload=args.enable_cpu_offload,
log_stats=args.log_stats,
)
print("Pipeline loaded")

Expand Down
6 changes: 6 additions & 0 deletions examples/offline_inference/image_to_video/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,11 @@ def parse_args() -> argparse.Namespace:
"Default 1 means pure sharding (no replication). "
),
)
parser.add_argument(
"--log-stats",
action="store_true",
help="Enable vLLM-Omni statistics logging.",
)
return parser.parse_args()


Expand Down Expand Up @@ -187,6 +192,7 @@ def main():
enable_cpu_offload=args.enable_cpu_offload,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
log_stats=args.log_stats,
)

if profiler_enabled:
Expand Down
6 changes: 6 additions & 0 deletions examples/offline_inference/text_to_image/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable cache-dit summary logging after diffusion forward passes.",
)
parser.add_argument(
"--log-stats",
action="store_true",
help="Enable vLLM-Omni statistics logging.",
)
parser.add_argument(
"--ulysses-degree",
type=int,
Expand Down Expand Up @@ -299,6 +304,7 @@ def main():
"parallel_config": parallel_config,
"enforce_eager": args.enforce_eager,
"enable_cpu_offload": args.enable_cpu_offload,
"log_stats": args.log_stats,
**lora_args,
**quant_kwargs,
}
Expand Down
6 changes: 6 additions & 0 deletions examples/offline_inference/text_to_video/text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def parse_args() -> argparse.Namespace:
default=1,
help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
)
parser.add_argument(
"--log-stats",
action="store_true",
help="Enable vLLM-Omni statistics logging.",
)
return parser.parse_args()


Expand Down Expand Up @@ -159,6 +164,7 @@ def main():
enable_cpu_offload=args.enable_cpu_offload,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
log_stats=args.log_stats,
)

if profiler_enabled:
Expand Down
17 changes: 14 additions & 3 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,19 @@ def __init__(self, od_config: OmniDiffusionConfig):
raise e

def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
diffusion_engine_start_time = time.time()
# Apply pre-processing if available
preprocess_time = 0.0
if self.pre_process_func is not None:
preprocess_start_time = time.time()
request = self.pre_process_func(request)
preprocess_time = time.time() - preprocess_start_time
logger.info(f"Pre-processing completed in {preprocess_time:.4f} seconds")

exec_start_time = time.time()
output = self.add_req_and_wait_for_response(request)
exec_total_time = time.time() - exec_start_time

if output.error:
raise Exception(f"{output.error}")
logger.info("Generation completed successfully.")
Expand Down Expand Up @@ -103,14 +108,20 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
outputs = [outputs] if outputs is not None else []

metrics = {
"preprocess_time_ms": round(preprocess_time * 1000, 2),
"diffusion_engine_exec_time_ms": round((time.time() - diffusion_engine_start_time) * 1000, 2),
"executor_time_ms": round(exec_total_time * 1000, 2),
"image_num": int(request.sampling_params.num_outputs_per_prompt),
"resolution": int(request.sampling_params.resolution),
"postprocess_time_ms": postprocess_time * 1000,
}

if self.pre_process_func is not None:
metrics["preprocessing_time_ms"] = preprocess_time * 1000
metrics["preprocessing_time_ms"] = round(preprocess_time * 1000, 2)

# Handle single request or multiple requests
metrics["postprocess_time_ms"] = round(postprocess_time * 1000, 2)
metrics["num_inference_steps"] = int(request.sampling_params.num_inference_steps)

if len(request.prompts) == 1:
# Single request: return single OmniRequestOutput
prompt = request.prompts[0]
Expand Down Expand Up @@ -177,7 +188,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
)
)

return results
return results
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This indentation change moves return results out of the else branch -- was this intentional? Double-check the single-prompt case still returns correctly.


@staticmethod
def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine":
Expand Down
9 changes: 9 additions & 0 deletions vllm_omni/entrypoints/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,15 @@ def _run_generation(
final_output_type=stage.final_output_type, # type: ignore[attr-defined]
request_output=engine_outputs,
)
try:
if stage.final_output_type == "text" or metrics.log_stats:
output_to_yield.metrics = metrics.build_output_metrics(stage_id, req_id)
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bare except Exception that silently swallows any bug in build_output_metrics. This makes metric issues very hard to debug. Why not let it propagate, or at minimum set output_to_yield.metrics = {} in the except block so the contract is explicit?

# Make metrics contract explicit on failure.
output_to_yield.metrics = {}
logger.exception(
f"[{self._name}] Failed to attach output metrics for req {req_id} at stage {stage_id}: {e}",
)

# Record audio generated frames (only when finished)
try:
Expand Down
64 changes: 50 additions & 14 deletions vllm_omni/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class StageRequestStats:
final_output_type: str | None = None
request_id: str | None = None
postprocess_time_ms: float = 0.0
diffusion_metrics: dict[str, int] = None
diffusion_metrics: dict[str, float] = None
audio_generated_frames: int = 0

@property
Expand Down Expand Up @@ -293,23 +293,15 @@ def process_stage_metrics(
return

# 4. Finished with output: assign text metrics if available
output_to_yield.metrics = {}
stage_event = next(
(evt for evt in reversed(self.stage_events.get(req_id, [])) if evt.stage_id == stage_id),
None,
)
if stage_event is not None and stage_event.final_output_type == "text":
output_to_yield.metrics = {
"num_tokens_in": stage_event.num_tokens_in,
"num_tokens_out": stage_event.num_tokens_out,
"stage_id": stage_event.stage_id,
"final_output_type": stage_event.final_output_type,
}
output_to_yield.metrics = self.build_output_metrics(stage_id, req_id)

# 5. Finished: record audio generated frames
self.record_audio_generated_frames(output_to_yield, stage_id, req_id)

except Exception:
if output_to_yield is not None:
# Make metrics contract explicit on failure.
output_to_yield.metrics = {}
logger.exception(
"Failed to process metrics for stage %s, req %s",
stage_id,
Expand All @@ -329,12 +321,56 @@ def _as_stage_request_stats(
stats.request_id = req_id
stats.final_output_type = final_output_type
stats.diffusion_metrics = (
{k: int(v) for k, v in self.diffusion_metrics.pop(req_id, {}).items()}
{k: float(v) for k, v in self.diffusion_metrics.pop(req_id, {}).items()}
if req_id in self.diffusion_metrics
else None
)
return stats

def _get_stage_event(self, stage_id: int, req_id: Any) -> StageRequestStats | None:
    """Return the latest recorded event for (stage_id, req_id), or None.

    Events are appended chronologically per request, so we scan
    newest-first and stop at the first stage match.
    """
    events = self.stage_events.get(str(req_id), [])
    return next((evt for evt in reversed(events) if evt.stage_id == stage_id), None)

def _collect_diffusion_metrics(self, req_id: Any) -> dict[str, float]:
"""Aggregate diffusion metrics across all stages for a request."""
rid_key = str(req_id)
merged: dict[str, float] = {}
for evt in self.stage_events.get(rid_key, []):
if not evt.diffusion_metrics:
continue
for key, value in evt.diffusion_metrics.items():
merged[key] = merged.get(key, 0.0) + float(value)
return merged

def build_output_metrics(self, stage_id: int, req_id: Any) -> dict[str, Any]:
    """Assemble the metrics dict to attach to a request's output.

    Returns an empty dict when no event exists for the stage, or when
    neither token counts (text outputs) nor diffusion stats (when
    ``log_stats`` is enabled) are available. A non-empty result is
    always stamped with ``stage_id`` and ``final_output_type``.
    """
    event = self._get_stage_event(stage_id, req_id)
    if event is None:
        return {}

    out: dict[str, Any] = {}

    # Token accounting only applies to text-producing stages.
    if event.final_output_type == "text":
        out["num_tokens_in"] = event.num_tokens_in
        out["num_tokens_out"] = event.num_tokens_out

    # Diffusion timing/shape stats are opt-in via log_stats.
    if self.log_stats:
        out.update(self._collect_diffusion_metrics(req_id))

    if not out:
        return out

    out["stage_id"] = event.stage_id
    out["final_output_type"] = event.final_output_type
    return out

def on_stage_metrics(
self,
stage_id: int,
Expand Down
47 changes: 44 additions & 3 deletions vllm_omni/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,60 @@ def __repr__(self) -> str:
"""Custom repr to properly show image count instead of image objects."""
# For images, show count instead of full list
images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]"

# Build repr string

def _repr_nested(obj) -> str:
if isinstance(obj, list):
return "[" + ", ".join(_repr_nested(x) for x in obj) + "]"
if isinstance(obj, OmniRequestOutput):
return obj._repr_multiline(indent=" ")
return repr(obj)

parts = [
f"request_id={self.request_id!r}",
f"finished={self.finished}",
f"stage_id={self.stage_id}",
f"final_output_type={self.final_output_type!r}",
f"request_output={self.request_output}",
f"request_output={_repr_nested(self.request_output)}",
f"images={images_repr}",
f"prompt={self.prompt!r}",
f"latents={self.latents}",
f"metrics={self.metrics}",
f"multimodal_output={self._multimodal_output}",
]

return f"OmniRequestOutput({', '.join(parts)})"

def _repr_multiline(self, indent: str = "") -> str:
    """Helper to produce multi-line, indented repr for nested logging.

    Renders this OmniRequestOutput one field per line at the given
    leading *indent*, recursing into nested OmniRequestOutput values
    (and lists of them) inside ``request_output`` with a deeper indent.
    Returns the assembled string; nothing is printed or logged here.
    """
    # Avoid dumping raw PIL objects: show a count instead.
    images_repr = f"[{len(self.images)} PIL Images]" if self.images else "[]"

    def _repr_nested(obj, ind: str) -> str:
        # Lists recurse element-by-element, each on its own line,
        # closing the bracket back at the list's own indent level.
        if isinstance(obj, list):
            inner = ",\n".join(_repr_nested(x, ind + " ") for x in obj)
            return "[\n" + inner + "\n" + ind + "]"
        # Nested outputs reuse this multi-line format one level deeper.
        if isinstance(obj, OmniRequestOutput):
            return obj._repr_multiline(indent=ind + " ")
        # Anything else falls back to its ordinary repr.
        return repr(obj)

    # Format metrics with each key-value pair on a separate line
    if self.metrics:
        metrics_indent = indent + " "
        metrics_lines = f",\n{metrics_indent}".join(f"{k!r}: {v!r}" for k, v in self.metrics.items())
        metrics_repr = f"{{\n{metrics_indent}{metrics_lines}\n{indent} }}"
    else:
        metrics_repr = "{}"

    # One field per line; trailing commas keep diffs small when
    # fields are added. NOTE(review): unlike __repr__, this omits
    # multimodal_output — confirm that asymmetry is intentional.
    lines = [
        f"{indent}OmniRequestOutput(",
        f"{indent} request_id={self.request_id!r},",
        f"{indent} finished={self.finished},",
        f"{indent} stage_id={self.stage_id},",
        f"{indent} final_output_type={self.final_output_type!r},",
        f"{indent} request_output={_repr_nested(self.request_output, indent + ' ')},",
        f"{indent} images={images_repr},",
        f"{indent} prompt={self.prompt!r},",
        f"{indent} latents={self.latents},",
        f"{indent} metrics={metrics_repr},",
        f"{indent})",
    ]
    return "\n".join(lines)
Loading