56 changes: 47 additions & 9 deletions docs/contributing/profiling.md
@@ -2,22 +2,28 @@

> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production.

vLLM-Omni uses the PyTorch Profiler to analyze performance across both **multi-stage omni-modality models** and **diffusion models**.
vLLM-Omni supports two profiling approaches:
- **PyTorch Profiler** — detailed CPU/CUDA traces (`*.pt.trace.json` files viewable in Perfetto)
- **Nsight Systems (nsys)** — GPU-level tracing with CUDA kernel timelines (`.nsys-rep` files)

### 1. Set the Output Directory
Before running any script, set this environment variable. The system detects this and automatically saves traces here.
### 1. Set the Output Directory (PyTorch Profiler)
Before running any profiling script, set this environment variable. The system detects this and automatically saves traces here.

```bash
export VLLM_TORCH_PROFILER_DIR=./profiles
```

### 2. Profiling Omni-Modality Models
### 2. Profiling Omni-Modality Models (Offline)

It is best to limit profiling to one iteration to keep trace files manageable.

```bash
export VLLM_PROFILER_MAX_ITERS=1
```
Optionally, skip initial warmup iterations before collecting traces:
```bash
export VLLM_PROFILER_DELAY_ITERS=1
```
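
Together, these two variables carve a capture window out of the iteration loop: skip the first `VLLM_PROFILER_DELAY_ITERS` iterations, then record at most `VLLM_PROFILER_MAX_ITERS`. A minimal sketch of that gating logic (illustrative only; the helper name is ours, not vLLM-Omni source):

```python
# Illustrative sketch: how VLLM_PROFILER_DELAY_ITERS and
# VLLM_PROFILER_MAX_ITERS together select which iterations are traced.
# This mirrors the documented behavior; it is not the actual vLLM-Omni code.

def should_capture(step: int, delay_iters: int, max_iters: int) -> bool:
    """Return True if the profiler should record this 0-indexed step."""
    if step < delay_iters:          # still in warmup, skip
        return False
    if max_iters and step >= delay_iters + max_iters:
        return False                # capture window exhausted
    return True

# With DELAY_ITERS=1 and MAX_ITERS=1, only the second iteration is traced.
print([should_capture(s, delay_iters=1, max_iters=1) for s in range(4)])
```

With `max_iters=0` the window is unbounded, matching the default of profiling every iteration after warmup.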

**Selective Stage Profiling**
The profiler runs across all stages by default, but it is highly recommended to profile specific stages by passing a stages list, which prevents trace files from growing too large:
@@ -82,7 +88,7 @@ omni_llm.close()
2. **Qwen3-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py)


### 3. Profiling diffusion models
### 3. Profiling Diffusion Models (Offline)

Diffusion profiling is end-to-end, capturing encoding, the denoising loop, and decoding.

@@ -131,15 +137,47 @@ python image_to_video.py \

2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video)

> **Note:**
> As of now, asynchronous (online) profiling is not fully supported in vLLM-Omni. While `start_profile()` and `stop_profile()` methods exist, they are only reliable in offline inference scripts (e.g., the provided `end2end.py` examples). Do not use them in server-mode or streaming scenarios; traces may be incomplete or fail to flush.
### 4. Nsight Systems Profiling (Diffusion)

For deeper GPU-level analysis of diffusion workloads, use NVIDIA Nsight Systems (`nsys`). Diffusion workers follow the same profiler pattern as vLLM — set `VLLM_TORCH_CUDA_PROFILE=1` to enable the CUDA profiler which signals nsys via `torch.cuda.profiler.start()/stop()`.

**Usage:**

```bash
# Enable CUDA profiler for nsys integration
export VLLM_TORCH_CUDA_PROFILE=1
# Capture a fixed range of iterations (skip warmup, then capture N iters)
export VLLM_PROFILER_DELAY_ITERS=10
export VLLM_PROFILER_MAX_ITERS=10
# Optional: enable NVTX ranges (used by vLLM tracing)
export VLLM_PROFILER_TRACE_DIR=./vllm_trace

nsys profile \
--capture-range=cudaProfilerApi \
--capture-range-end=stop \
--trace-fork-before-exec=true \
--cuda-graph-trace=node \
--sample=none \
--stats=true \
-o diffusion_trace \
python image_to_video.py --model Wan-AI/Wan2.2-I2V-A14B-Diffusers ...
```

The `VLLM_TORCH_CUDA_PROFILE=1` environment variable configures diffusion workers to use vLLM's `CudaProfilerWrapper`, which brackets GPU work with `torch.cuda.profiler.start()/stop()` calls that nsys captures.
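
The bracketing itself is a simple start/stop pattern that `nsys --capture-range=cudaProfilerApi` keys on. A hypothetical sketch (the `cuda_profiler_range` helper is ours for illustration; in the real wrapper the callables are `torch.cuda.profiler.start` and `torch.cuda.profiler.stop`):

```python
# Hypothetical sketch of the start()/stop() bracketing captured by
# `nsys --capture-range=cudaProfilerApi`. Only work inside the range
# is recorded; everything before start() and after stop() is skipped.
from contextlib import contextmanager
from typing import Callable, Iterator

@contextmanager
def cuda_profiler_range(start_fn: Callable[[], None],
                        stop_fn: Callable[[], None]) -> Iterator[None]:
    start_fn()
    try:
        yield
    finally:
        stop_fn()  # runs even if the bracketed work raises

# Stand-in callables that just record the call order:
calls: list[str] = []
with cuda_profiler_range(lambda: calls.append("start"),
                         lambda: calls.append("stop")):
    calls.append("denoise_step")
print(calls)  # order: start, denoise_step, stop
```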

```bash
ls diffusion_trace*.nsys-rep
nsys stats diffusion_trace.nsys-rep
```

Open the `.nsys-rep` file in the Nsight Systems GUI for detailed CUDA kernel timelines, memory operations, and NVTX ranges.

### 4. Analyzing Omni Traces
### 5. Analyzing Omni Traces

Output files are saved to your configured `VLLM_TORCH_PROFILER_DIR`.

**Output**
**Chrome Trace** (`.json.gz`): Visual timeline of kernels and stages. Open in Perfetto UI.
**Chrome Trace** (`.pt.trace.json`): Visual timeline of kernels and stages. Open in Perfetto UI.
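
Beyond visual inspection in Perfetto, these traces follow the Chrome trace-event format (a JSON object with a `traceEvents` list), so they can be summarized programmatically. A sketch under that assumption (the `top_ops` helper is ours, not a vLLM-Omni API):

```python
# Summarize total time per op/kernel name from a PyTorch profiler trace.
# Assumes the standard Chrome trace-event layout ("traceEvents", complete
# events with ph == "X" and a microsecond "dur"), which *.pt.trace.json uses.
import gzip
import json
from collections import Counter

def top_ops(trace_path: str, n: int = 5) -> list[tuple[str, int]]:
    opener = gzip.open if trace_path.endswith(".gz") else open
    with opener(trace_path, "rt") as f:
        events = json.load(f).get("traceEvents", [])
    totals: Counter = Counter()
    for ev in events:
        if ev.get("ph") == "X" and "dur" in ev:   # complete event with duration
            totals[ev.get("name", "<unnamed>")] += ev["dur"]
    return totals.most_common(n)                  # [(name, total_us), ...]
```

This is handy for quickly ranking the heaviest kernels across a multi-stage trace before opening the full timeline.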

**Viewing Tools:**

120 changes: 49 additions & 71 deletions vllm_omni/diffusion/diffusion_engine.py
@@ -196,60 +196,46 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest):

def start_profile(self, trace_filename: str | None = None) -> None:
"""
Start torch profiling on all diffusion workers.
Start profiling on all diffusion workers.

Creates a directory (if needed) and sets up a base filename template
for per-rank profiler traces (typically saved as <template>_rank<N>.json).

Args:
trace_filename: Optional base filename (without extension or rank suffix).
If None, generates one using current timestamp.
Profiling is configured via vLLM's profiler config/environment variables:
- PyTorch profiler: VLLM_TORCH_PROFILER_DIR
- Nsight Systems (cuda profiler): VLLM_TORCH_CUDA_PROFILE=1
"""
if trace_filename is None:
trace_filename = f"stage_0_diffusion_{int(time.time())}_rank"

trace_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles")

# Expand ~ and ~user, then make absolute (robust against cwd changes)
trace_dir = os.path.expanduser(trace_dir)
trace_dir = os.path.abspath(trace_dir)

try:
os.makedirs(trace_dir, exist_ok=True)
except OSError as exc:
logger.error(f"Failed to create profiler directory {trace_dir}: {exc}")
raise

# Build final template path (without rank or extension — torch.profiler appends those)
full_template = os.path.join(trace_dir, trace_filename)

expected_pattern = f"{full_template}*.json"
logger.info(f"Starting diffusion profiling → {expected_pattern}")
if trace_filename:
logger.debug(
"Diffusion profiling uses vLLM profiler config; trace_filename is ignored (%s).",
trace_filename,
)

# Also log the absolute directory once (useful in multi-node or containers)
logger.debug(f"Profiler output directory: {trace_dir}")
trace_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR")
if trace_dir:
trace_dir = os.path.abspath(os.path.expanduser(trace_dir))
try:
os.makedirs(trace_dir, exist_ok=True)
except OSError as exc:
logger.error("Failed to create profiler directory %s: %s", trace_dir, exc)
raise
logger.info("Starting diffusion profiling. Torch traces will be written under %s", trace_dir)
else:
logger.info("Starting diffusion profiling.")

# Propagate to all workers
try:
self.collective_rpc(method="start_profile", args=(full_template,))
self.collective_rpc(method="start_profile")
except Exception as e:
logger.error("Failed to start profiling on workers", exc_info=True)
raise RuntimeError(f"Could not start profiler: {e}") from e

def stop_profile(self) -> dict:
"""
Stop profiling on all workers and collect the final trace/table paths.

The worker (torch_profiler.py) now handles trace export, compression to .gz,
and deletion of the original .json file. This method only collects and
reports the paths returned by the workers.
Stop profiling on all workers and best-effort collect any legacy outputs.

Returns:
dict with keys:
- "traces": list of final trace file paths (usually .json.gz)
- "tables": list of table strings (one per rank)
vLLM's profiler wrappers write traces directly to disk and do not return
per-rank file paths. This method preserves backward compatibility by
aggregating any dict-like results if present.
"""
logger.info("Stopping diffusion profiling and collecting results...")
logger.info("Stopping diffusion profiling...")

try:
# Give worker enough time — export + compression + table can be slow
@@ -262,54 +248,46 @@ def stop_profile(self) -> dict:
successful_traces = 0

if not results:
logger.warning("No profiling results returned from any rank")
logger.info("No profiling results returned from any rank.")
return output_files

for rank, res in enumerate(results):
if res is None:
# vLLM profiler wrappers return no per-rank payloads.
continue
if not isinstance(res, dict):
logger.warning(f"Rank {rank}: invalid result format (got {type(res)})")
logger.warning("Rank %s: invalid result format (got %s)", rank, type(res))
continue

# 1. Trace file — should be .json.gz if compression succeeded
trace_path = res.get("trace")
trace_path = res.get("trace") or res.get("traces")
if trace_path:
# We trust the worker — it created/compressed the file
logger.info(f"[Rank {rank}] Final trace: {trace_path}")
output_files["traces"].append(trace_path)
successful_traces += 1
if isinstance(trace_path, str):
output_files["traces"].append(trace_path)
elif isinstance(trace_path, list):
output_files["traces"].extend(trace_path)
successful_traces = len(output_files["traces"])
**Review comment:**

Since workers always return `None` right now, the entire `for rank, res in enumerate(results)` loop body is effectively dead code (the `if res is None: continue` skips everything). This will become useful after fixing the worker's `stop_profile()` to return the wrapper's result.


# Optional: warn if path looks suspicious (e.g. still .json)
if not trace_path.endswith((".json.gz", ".json")):
logger.warning(f"Rank {rank}: unusual trace path extension: {trace_path}")

# 2. Table file — plain text
table = res.get("table")
table = res.get("table") or res.get("tables")
if table:
output_files["tables"].append(table)
if isinstance(table, str):
output_files["tables"].append(table)
elif isinstance(table, list):
output_files["tables"].extend(table)

# Final summary logging
num_ranks = len(results)
if successful_traces > 0:
final_paths_str = ", ".join(output_files["traces"][:3])
if len(output_files["traces"]) > 3:
final_paths_str += f" ... (+{len(output_files['traces']) - 3} more)"

logger.info(
f"Profiling stopped. Collected {successful_traces} trace file(s) "
f"from {num_ranks} rank(s). "
f"Final trace paths: {final_paths_str}"
"Profiling stopped. Collected %s trace file(s) from %s rank(s).",
successful_traces,
len(results),
)
elif output_files["traces"]:
else:
logger.info(
f"Profiling stopped but no traces were successfully collected. "
f"Reported paths: {', '.join(output_files['traces'][:3])}"
f"{' ...' if len(output_files['traces']) > 3 else ''}"
"Profiling stopped. Traces are written by the active profiler "
"(PyTorch: VLLM_TORCH_PROFILER_DIR, nsys: -o output)."
)
else:
logger.info("Profiling stopped — no trace files were collected from any rank.")

if output_files["tables"]:
logger.debug(f"Collected {len(output_files['tables'])} profiling table(s)")
logger.debug("Collected %s profiling table(s)", len(output_files["tables"]))

return output_files

54 changes: 45 additions & 9 deletions vllm_omni/diffusion/worker/diffusion_worker.py
@@ -10,12 +10,14 @@
import multiprocessing as mp
import os
from contextlib import AbstractContextManager, nullcontext
from typing import Any

import torch
import zmq
from vllm.config import VllmConfig
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.utils.mem_utils import GiB_bytes

from vllm_omni.diffusion.data import (
Expand All @@ -29,7 +31,6 @@
)
from vllm_omni.diffusion.forward_context import set_forward_context
from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
from vllm_omni.diffusion.profiler import CurrentProfiler
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner
from vllm_omni.lora.request import LoRARequest
@@ -65,6 +66,7 @@ def __init__(
self.model_runner: DiffusionModelRunner | None = None
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
self.lora_manager: DiffusionLoRAManager | None = None
self.profiler: Any | None = None
self.init_device()

def init_device(self) -> None:
@@ -89,6 +91,20 @@ def init_device(self) -> None:
vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size
self.vllm_config = vllm_config

# Initialize profiler based on profiler_config (follows vLLM pattern)
profiler_config = vllm_config.profiler_config
if profiler_config.profiler == "torch":
worker_name = f"diffusion-rank-{self.rank}"
self.profiler = TorchProfilerWrapper(
profiler_config,
worker_name=worker_name,
**Review comment:**

Missing `activities` parameter. vLLM's `gpu_worker.py` explicitly passes `activities=["CPU", "CUDA"]`:

```python
self.profiler = TorchProfilerWrapper(
    profiler_config,
    worker_name=worker_name,
    local_rank=self.local_rank,
    activities=["CPU", "CUDA"],  # <-- add this
)
```

Without it, the torch profiler may not capture CUDA kernels, which defeats the purpose of nsys integration.

local_rank=self.local_rank,
)
elif profiler_config.profiler == "cuda":
self.profiler = CudaProfilerWrapper(profiler_config)
else:
self.profiler = None

# Initialize distributed environment
with set_forward_context(vllm_config=vllm_config, omni_diffusion_config=self.od_config):
init_distributed_environment(world_size=world_size, rank=rank)
@@ -129,15 +145,27 @@ def generate(self, request: OmniDiffusionRequest) -> DiffusionOutput:
"""Generate output for the given requests."""
return self.execute_model(request, self.od_config)

@classmethod
def start_profile(cls, trace_path_template: str) -> str:
"""Start profiling for this GPU worker."""
return CurrentProfiler.start(trace_path_template)
def start_profile(self, trace_path_template: str = "") -> str:
"""Start profiling for this GPU worker.

Uses vLLM's profiler wrappers based on profiler_config:
- 'torch': TorchProfilerWrapper for detailed CPU/CUDA traces
- 'cuda': CudaProfilerWrapper for nsys integration

@classmethod
def stop_profile(cls) -> dict | None:
Set VLLM_TORCH_CUDA_PROFILE=1 for nsys/cuda profiler, or
VLLM_TORCH_PROFILER_DIR for torch profiler.
"""
if self.profiler is not None:
self.profiler.start()
logger.info("Diffusion worker %s: profiler started", self.rank)
**Review comment:**

The `trace_path_template` parameter is accepted but never used — vLLM's wrappers get their paths from `profiler_config` at init time. This is confusing for callers. Consider removing it entirely or at minimum documenting that it's ignored.

return trace_path_template

def stop_profile(self) -> dict | None:
"""Stop profiling and return the result dictionary."""
return CurrentProfiler.stop()
if self.profiler is not None:
self.profiler.stop()
logger.info("Diffusion worker %s: profiler stopped", self.rank)
return None

def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput:
**Review comment:**

`stop_profile()` always returns `None`, which means `DiffusionEngine.stop_profile()` never gets any trace paths from workers. The elaborate aggregation logic in the engine becomes dead code.

`TorchProfilerWrapper.stop()` returns a dict with trace file paths — please return that result instead of discarding it:

```python
def stop_profile(self) -> dict | None:
    if self.profiler is not None:
        return self.profiler.stop()
    return None
```

"""Execute a forward pass by delegating to the model runner."""
@@ -149,7 +177,15 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput:
if req.sampling_params.lora_request is not None:
raise
logger.warning("LoRA activation skipped: %s", exc)
return self.model_runner.execute_model(req)
profiler_context = (
self.profiler.annotate_context_manager("diffusion_forward") if self.profiler is not None else nullcontext()
)
with profiler_context:
**Review comment:**

Good use of `annotate_context_manager` and `step()` — this follows vLLM's pattern and gives clean trace segmentation per forward pass.

output = self.model_runner.execute_model(req)
if self.profiler is not None:
# Drive delayed start/auto-stop behavior to match vLLM's profiler wrapper.
self.profiler.step()
return output

def load_weights(self, weights) -> set[str]:
"""Load weights by delegating to the model runner."""
11 changes: 7 additions & 4 deletions vllm_omni/entrypoints/omni.py
@@ -379,8 +379,10 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None:
def start_profile(self, stages: list[int] | None = None) -> None:
"""Start profiling for specified stages.

Sends start_profile command to stage workers. Profiling must be enabled
via VLLM_TORCH_PROFILER_DIR environment variable.
Sends start_profile command to stage workers. Profiling is configured
via vLLM profiler environment variables, e.g.:
- VLLM_TORCH_PROFILER_DIR for PyTorch profiler traces
- VLLM_TORCH_CUDA_PROFILE=1 for Nsight Systems (cuda profiler)

Args:
stages: List of stage IDs to start profiling. If None, starts
@@ -432,6 +434,9 @@ def stop_profile(self, stages: list[int] | None = None) -> dict:
# This is the blocking call that triggers the RPC chain
stage_data = stage.stop_profile()

if stage_data is None:
continue

if isinstance(stage_data, dict):
# FIX: Handle both single key and list key formats
traces = stage_data.get("trace") or stage_data.get("traces")
@@ -457,8 +462,6 @@ def stop_profile(self, stages: list[int] | None = None) -> dict:
all_results["tables"].append(tables)
elif isinstance(tables, list):
all_results["tables"].extend(tables)
else:
logger.warning(f"[{self._name}] Stage-{stage_id} returned no table data")
else:
logger.warning(f"[{self._name}] Stage-{stage_id} returned non-dict data: {type(stage_data)}")
else:
Expand Down