Commit 33e451f

ahengljh and claude committed:

[Profiler] Follow vLLM pattern for diffusion profiler integration

Use vLLM's CudaProfilerWrapper/TorchProfilerWrapper in DiffusionWorker instead of a custom implementation. This unifies the profiler approach between omni models and diffusion models.

- Import and use vLLM's profiler wrappers based on profiler_config
- VLLM_TORCH_CUDA_PROFILE=1 enables CudaProfilerWrapper for nsys
- VLLM_TORCH_PROFILER_DIR enables TorchProfilerWrapper for traces
- Remove the dependency on CurrentProfiler from the diffusion profiler module
- Update docs with vLLM-style nsys usage

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Jinheng Li <ahengljh@gmail.com>

1 parent 9c71ca7 commit 33e451f

File tree

2 files changed: +41 -26 lines changed

docs/contributing/profiling.md

Lines changed: 5 additions & 2 deletions

````diff
@@ -135,11 +135,14 @@ python image_to_video.py \
 
 ### 4. Nsight Systems Profiling (Diffusion)
 
-For deeper GPU-level analysis of diffusion workloads, use NVIDIA Nsight Systems (`nsys`). The diffusion worker integrates with nsys via `torch.cuda.profiler.start()/stop()` when profiling is triggered.
+For deeper GPU-level analysis of diffusion workloads, use NVIDIA Nsight Systems (`nsys`). Diffusion workers follow the same profiler pattern as vLLM: set `VLLM_TORCH_CUDA_PROFILE=1` to enable the CUDA profiler, which signals nsys via `torch.cuda.profiler.start()/stop()`.
 
 **Usage:**
 
 ```bash
+# Enable CUDA profiler for nsys integration
+export VLLM_TORCH_CUDA_PROFILE=1
+
 nsys profile \
     --capture-range=cudaProfilerApi \
     --capture-range-end=repeat \
@@ -149,7 +152,7 @@ nsys profile \
 python image_to_video.py --model Wan-AI/Wan2.2-I2V-A14B-Diffusers ...
 ```
 
-Set `VLLM_TORCH_PROFILER_DIR` to trigger profiling, which also opens nsys capture regions in diffusion worker processes.
+The `VLLM_TORCH_CUDA_PROFILE=1` environment variable configures diffusion workers to use vLLM's `CudaProfilerWrapper`, which brackets GPU work with `torch.cuda.profiler.start()/stop()` calls that nsys captures.
 
 ```bash
 ls diffusion_trace*.nsys-rep
````

vllm_omni/diffusion/worker/diffusion_worker.py

Lines changed: 36 additions & 24 deletions

````diff
@@ -13,9 +13,12 @@
 
 import torch
 import zmq
+from typing import Any
+
 from vllm.config import VllmConfig
 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.logger import init_logger
+from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
 from vllm.utils.mem_utils import GiB_bytes
 
 from vllm_omni.diffusion.data import (
@@ -29,7 +32,6 @@
 )
 from vllm_omni.diffusion.forward_context import set_forward_context
 from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
-from vllm_omni.diffusion.profiler import CurrentProfiler
 from vllm_omni.diffusion.request import OmniDiffusionRequest
 from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner
 from vllm_omni.lora.request import LoRARequest
@@ -65,6 +67,7 @@ def __init__(
         self.model_runner: DiffusionModelRunner | None = None
         self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
         self.lora_manager: DiffusionLoRAManager | None = None
+        self.profiler: Any | None = None
         self.init_device()
 
     def init_device(self) -> None:
@@ -89,6 +92,21 @@ def init_device(self) -> None:
         vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size
         self.vllm_config = vllm_config
 
+        # Initialize profiler based on profiler_config (follows vLLM pattern)
+        profiler_config = vllm_config.profiler_config
+        if profiler_config.profiler == "torch":
+            worker_name = f"diffusion-rank-{self.rank}"
+            self.profiler = TorchProfilerWrapper(
+                profiler_config,
+                worker_name=worker_name,
+                local_rank=self.local_rank,
+                activities=["CPU", "CUDA"],
+            )
+        elif profiler_config.profiler == "cuda":
+            self.profiler = CudaProfilerWrapper(profiler_config)
+        else:
+            self.profiler = None
+
         # Initialize distributed environment
         with set_forward_context(vllm_config=vllm_config, omni_diffusion_config=self.od_config):
             init_distributed_environment(world_size=world_size, rank=rank)
@@ -129,33 +147,27 @@ def generate(self, request: OmniDiffusionRequest) -> DiffusionOutput:
         """Generate output for the given requests."""
         return self.execute_model(request, self.od_config)
 
-    @classmethod
-    def start_profile(cls, trace_path_template: str) -> str:
+    def start_profile(self, trace_path_template: str = "") -> str:
         """Start profiling for this GPU worker.
 
-        Also opens a CUDA profiler capture region so that nsys (when
-        launched with ``--capture-range=cudaProfilerApi``) records GPU
-        activity from within this worker process.
-        """
-        if torch.cuda.is_available():
-            try:
-                torch.cuda.profiler.start()
-            except Exception as e:
-                logger.warning("Failed to start CUDA profiler in DiffusionWorker: %s", e)
-        return CurrentProfiler.start(trace_path_template)
-
-    @classmethod
-    def stop_profile(cls) -> dict | None:
-        """Stop profiling and return the result dictionary.
+        Uses vLLM's profiler wrappers based on profiler_config:
+        - 'torch': TorchProfilerWrapper for detailed CPU/CUDA traces
+        - 'cuda': CudaProfilerWrapper for nsys integration
 
-        Also closes the CUDA profiler capture region for nsys.
+        Set VLLM_TORCH_CUDA_PROFILE=1 for nsys/cuda profiler, or
+        VLLM_TORCH_PROFILER_DIR for torch profiler.
         """
-        if torch.cuda.is_available():
-            try:
-                torch.cuda.profiler.stop()
-            except Exception as e:
-                logger.warning("Failed to stop CUDA profiler in DiffusionWorker: %s", e)
-        return CurrentProfiler.stop()
+        if self.profiler is not None:
+            self.profiler.start()
+            logger.info("Diffusion worker %s: profiler started", self.rank)
+        return trace_path_template
+
+    def stop_profile(self) -> dict | None:
+        """Stop profiling and return the result dictionary."""
+        if self.profiler is not None:
+            self.profiler.stop()
+            logger.info("Diffusion worker %s: profiler stopped", self.rank)
+        return None
 
     def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput:
         """Execute a forward pass by delegating to the model runner."""
````
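The reworked lifecycle can be exercised end to end with a stand-in profiler. This is a hedged sketch: `FakeProfiler` and `MiniWorker` are hypothetical classes that mirror only the `start_profile`/`stop_profile` shape the commit gives `DiffusionWorker`, without CUDA or vLLM config.

```python
class FakeProfiler:
    """Hypothetical stand-in for CudaProfilerWrapper/TorchProfilerWrapper;
    it records start/stop calls instead of driving a real profiler."""
    def __init__(self) -> None:
        self.events: list[str] = []

    def start(self) -> None:
        self.events.append("start")

    def stop(self) -> None:
        self.events.append("stop")


class MiniWorker:
    """Minimal sketch of the profiling surface DiffusionWorker now exposes."""
    def __init__(self, profiler) -> None:
        # In the real worker, self.profiler is set during init_device()
        # from vllm_config.profiler_config; None disables profiling.
        self.profiler = profiler

    def start_profile(self, trace_path_template: str = "") -> str:
        if self.profiler is not None:
            self.profiler.start()
        return trace_path_template

    def stop_profile(self) -> None:
        if self.profiler is not None:
            self.profiler.stop()
        return None


worker = MiniWorker(FakeProfiler())
worker.start_profile()
# ... generation steps would run between start and stop ...
worker.stop_profile()
print(worker.profiler.events)  # ['start', 'stop']

# A worker with no profiler configured degrades to a safe no-op.
print(MiniWorker(None).start_profile("trace-%d"))  # trace-%d
```

Note the design shift the diff makes: `start_profile`/`stop_profile` become instance methods rather than classmethods, so each worker rank owns its wrapper instead of sharing process-global profiler state.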
