Skip to content

Commit b0c7853

Browse files
committed
Align diffusion profiling with vLLM
Signed-off-by: Jinheng Li <ahengljh@gmail.com>
1 parent 33e451f commit b0c7853

File tree

6 files changed

+98
-99
lines changed

6 files changed

+98
-99
lines changed

docs/contributing/profiling.md

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production.
44
55
vLLM-Omni supports two profiling approaches:
6-
- **PyTorch Profiler** — detailed CPU/CUDA traces (`.json.gz` files viewable in Perfetto)
6+
- **PyTorch Profiler** — detailed CPU/CUDA traces (`*.pt.trace.json` files viewable in Perfetto)
77
- **Nsight Systems (nsys)** — GPU-level tracing with CUDA kernel timelines (`.nsys-rep` files)
88

99
### 1. Set the Output Directory (PyTorch Profiler)
@@ -20,6 +20,10 @@ It is best to limit profiling to one iteration to keep trace files manageable.
2020
```bash
2121
export VLLM_PROFILER_MAX_ITERS=1
2222
```
23+
Optionally, skip initial warmup iterations before collecting traces:
24+
```bash
25+
export VLLM_PROFILER_DELAY_ITERS=1
26+
```
2327

2428
**Selective Stage Profiling**
2529
The profiler defaults to running across all stages. However, it is highly recommended to profile specific stages by passing a list of stage IDs, to avoid producing excessively large trace files:
@@ -142,12 +146,19 @@ For deeper GPU-level analysis of diffusion workloads, use NVIDIA Nsight Systems
142146
```bash
143147
# Enable CUDA profiler for nsys integration
144148
export VLLM_TORCH_CUDA_PROFILE=1
149+
# Capture a fixed range of iterations (skip warmup, then capture N iters)
150+
export VLLM_PROFILER_DELAY_ITERS=10
151+
export VLLM_PROFILER_MAX_ITERS=10
152+
# Optional: enable NVTX ranges (used by vLLM tracing)
153+
export VLLM_PROFILER_TRACE_DIR=./vllm_trace
145154

146155
nsys profile \
147156
--capture-range=cudaProfilerApi \
148-
--capture-range-end=repeat \
157+
--capture-range-end=stop \
149158
--trace-fork-before-exec=true \
150159
--cuda-graph-trace=node \
160+
--sample=none \
161+
--stats=true \
151162
-o diffusion_trace \
152163
python image_to_video.py --model Wan-AI/Wan2.2-I2V-A14B-Diffusers ...
153164
```
@@ -166,7 +177,7 @@ Open the `.nsys-rep` file in the Nsight Systems GUI for detailed CUDA kernel tim
166177
Output files are saved to your configured ```VLLM_TORCH_PROFILER_DIR```.
167178

168179
**Output**
169-
**Chrome Trace** (```.json.gz```): Visual timeline of kernels and stages. Open in Perfetto UI.
180+
**Chrome Trace** (```.pt.trace.json```): Visual timeline of kernels and stages. Open in Perfetto UI.
170181

171182
**Viewing Tools:**
172183

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 50 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -196,63 +196,49 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest):
196196

197197
def start_profile(self, trace_filename: str | None = None) -> None:
198198
"""
199-
Start torch profiling on all diffusion workers.
199+
Start profiling on all diffusion workers.
200200
201-
Creates a directory (if needed) and sets up a base filename template
202-
for per-rank profiler traces (typically saved as <template>_rank<N>.json).
203-
204-
Args:
205-
trace_filename: Optional base filename (without extension or rank suffix).
206-
If None, generates one using current timestamp.
201+
Profiling is configured via vLLM's profiler config/environment variables:
202+
- PyTorch profiler: VLLM_TORCH_PROFILER_DIR
203+
- Nsight Systems (cuda profiler): VLLM_TORCH_CUDA_PROFILE=1
207204
"""
208-
if trace_filename is None:
209-
trace_filename = f"stage_0_diffusion_{int(time.time())}_rank"
210-
211-
trace_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles")
212-
213-
# Expand ~ and ~user, then make absolute (robust against cwd changes)
214-
trace_dir = os.path.expanduser(trace_dir)
215-
trace_dir = os.path.abspath(trace_dir)
216-
217-
try:
218-
os.makedirs(trace_dir, exist_ok=True)
219-
except OSError as exc:
220-
logger.error(f"Failed to create profiler directory {trace_dir}: {exc}")
221-
raise
222-
223-
# Build final template path (without rank or extension — torch.profiler appends those)
224-
full_template = os.path.join(trace_dir, trace_filename)
225-
226-
expected_pattern = f"{full_template}*.json"
227-
logger.info(f"Starting diffusion profiling → {expected_pattern}")
205+
if trace_filename:
206+
logger.debug(
207+
"Diffusion profiling uses vLLM profiler config; trace_filename is ignored (%s).",
208+
trace_filename,
209+
)
228210

229-
# Also log the absolute directory once (useful in multi-node or containers)
230-
logger.debug(f"Profiler output directory: {trace_dir}")
211+
trace_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR")
212+
if trace_dir:
213+
trace_dir = os.path.abspath(os.path.expanduser(trace_dir))
214+
try:
215+
os.makedirs(trace_dir, exist_ok=True)
216+
except OSError as exc:
217+
logger.error("Failed to create profiler directory %s: %s", trace_dir, exc)
218+
raise
219+
logger.info("Starting diffusion profiling. Torch traces will be written under %s", trace_dir)
220+
else:
221+
logger.info("Starting diffusion profiling.")
231222

232223
# Propagate to all workers
233224
try:
234-
self.collective_rpc(method="start_profile", args=(full_template,))
225+
self.collective_rpc(method="start_profile")
235226
except Exception as e:
236227
logger.error("Failed to start profiling on workers", exc_info=True)
237228
raise RuntimeError(f"Could not start profiler: {e}") from e
238229

239230
def stop_profile(self) -> dict:
240231
"""
241-
Stop profiling on all workers and collect the final trace/table paths.
242-
243-
The worker (torch_profiler.py) now handles trace export, compression to .gz,
244-
and deletion of the original .json file. This method only collects and
245-
reports the paths returned by the workers.
232+
Stop profiling on all workers and best-effort collect any legacy outputs.
246233
247-
Returns:
248-
dict with keys:
249-
- "traces": list of final trace file paths (usually .json.gz)
250-
- "tables": list of table strings (one per rank)
234+
vLLM's profiler wrappers write traces directly to disk and do not return
235+
per-rank file paths. This method preserves backward compatibility by
236+
aggregating any dict-like results if present.
251237
"""
252-
logger.info("Stopping diffusion profiling and collecting results...")
238+
logger.info("Stopping diffusion profiling...")
253239

254240
try:
255-
# Give worker enough time — export + compression + table can be slow
241+
# Give workers enough time — trace flushing can be slow
256242
results = self.collective_rpc(method="stop_profile", timeout=60000)
257243
except Exception:
258244
logger.error("Failed to stop profiling on workers", exc_info=True)
@@ -262,54 +248,46 @@ def stop_profile(self) -> dict:
262248
successful_traces = 0
263249

264250
if not results:
265-
logger.warning("No profiling results returned from any rank")
251+
logger.info("No profiling results returned from any rank.")
266252
return output_files
267253

268254
for rank, res in enumerate(results):
255+
if res is None:
256+
# vLLM profiler wrappers return no per-rank payloads.
257+
continue
269258
if not isinstance(res, dict):
270-
logger.warning(f"Rank {rank}: invalid result format (got {type(res)})")
259+
logger.warning("Rank %s: invalid result format (got %s)", rank, type(res))
271260
continue
272261

273-
# 1. Trace file — should be .json.gz if compression succeeded
274-
trace_path = res.get("trace")
262+
trace_path = res.get("trace") or res.get("traces")
275263
if trace_path:
276-
# We trust the worker — it created/compressed the file
277-
logger.info(f"[Rank {rank}] Final trace: {trace_path}")
278-
output_files["traces"].append(trace_path)
279-
successful_traces += 1
264+
if isinstance(trace_path, str):
265+
output_files["traces"].append(trace_path)
266+
elif isinstance(trace_path, list):
267+
output_files["traces"].extend(trace_path)
268+
successful_traces = len(output_files["traces"])
280269

281-
# Optional: warn if path looks suspicious (e.g. still .json)
282-
if not trace_path.endswith((".json.gz", ".json")):
283-
logger.warning(f"Rank {rank}: unusual trace path extension: {trace_path}")
284-
285-
# 2. Table file — plain text
286-
table = res.get("table")
270+
table = res.get("table") or res.get("tables")
287271
if table:
288-
output_files["tables"].append(table)
272+
if isinstance(table, str):
273+
output_files["tables"].append(table)
274+
elif isinstance(table, list):
275+
output_files["tables"].extend(table)
289276

290-
# Final summary logging
291-
num_ranks = len(results)
292277
if successful_traces > 0:
293-
final_paths_str = ", ".join(output_files["traces"][:3])
294-
if len(output_files["traces"]) > 3:
295-
final_paths_str += f" ... (+{len(output_files['traces']) - 3} more)"
296-
297278
logger.info(
298-
f"Profiling stopped. Collected {successful_traces} trace file(s) "
299-
f"from {num_ranks} rank(s). "
300-
f"Final trace paths: {final_paths_str}"
279+
"Profiling stopped. Collected %s trace file(s) from %s rank(s).",
280+
successful_traces,
281+
len(results),
301282
)
302-
elif output_files["traces"]:
283+
else:
303284
logger.info(
304-
f"Profiling stopped but no traces were successfully collected. "
305-
f"Reported paths: {', '.join(output_files['traces'][:3])}"
306-
f"{' ...' if len(output_files['traces']) > 3 else ''}"
285+
"Profiling stopped. Traces are written by the active profiler "
286+
"(PyTorch: VLLM_TORCH_PROFILER_DIR, nsys: -o output)."
307287
)
308-
else:
309-
logger.info("Profiling stopped — no trace files were collected from any rank.")
310288

311289
if output_files["tables"]:
312-
logger.debug(f"Collected {len(output_files['tables'])} profiling table(s)")
290+
logger.debug("Collected %s profiling table(s)", len(output_files["tables"]))
313291

314292
return output_files
315293

vllm_omni/diffusion/worker/diffusion_worker.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def init_device(self) -> None:
100100
profiler_config,
101101
worker_name=worker_name,
102102
local_rank=self.local_rank,
103-
activities=["CPU", "CUDA"],
104103
)
105104
elif profiler_config.profiler == "cuda":
106105
self.profiler = CudaProfilerWrapper(profiler_config)
@@ -179,7 +178,15 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfi
179178
if req.sampling_params.lora_request is not None:
180179
raise
181180
logger.warning("LoRA activation skipped: %s", exc)
182-
return self.model_runner.execute_model(req)
181+
profiler_context = (
182+
self.profiler.annotate_context_manager("diffusion_forward") if self.profiler is not None else nullcontext()
183+
)
184+
with profiler_context:
185+
output = self.model_runner.execute_model(req)
186+
if self.profiler is not None:
187+
# Drive delayed start/auto-stop behavior to match vLLM's profiler wrapper.
188+
self.profiler.step()
189+
return output
183190

184191
def load_weights(self, weights) -> set[str]:
185192
"""Load weights by delegating to the model runner."""

vllm_omni/entrypoints/omni.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,10 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None:
379379
def start_profile(self, stages: list[int] | None = None) -> None:
380380
"""Start profiling for specified stages.
381381
382-
Sends start_profile command to stage workers. Profiling must be enabled
383-
via VLLM_TORCH_PROFILER_DIR environment variable.
382+
Sends start_profile command to stage workers. Profiling is configured
383+
via vLLM profiler environment variables, e.g.:
384+
- VLLM_TORCH_PROFILER_DIR for PyTorch profiler traces
385+
- VLLM_TORCH_CUDA_PROFILE=1 for Nsight Systems (cuda profiler)
384386
385387
Args:
386388
stages: List of stage IDs to start profiling. If None, starts
@@ -432,6 +434,9 @@ def stop_profile(self, stages: list[int] | None = None) -> dict:
432434
# This is the blocking call that triggers the RPC chain
433435
stage_data = stage.stop_profile()
434436

437+
if stage_data is None:
438+
continue
439+
435440
if isinstance(stage_data, dict):
436441
# FIX: Handle both single key and list key formats
437442
traces = stage_data.get("trace") or stage_data.get("traces")
@@ -457,8 +462,6 @@ def stop_profile(self, stages: list[int] | None = None) -> dict:
457462
all_results["tables"].append(tables)
458463
elif isinstance(tables, list):
459464
all_results["tables"].extend(tables)
460-
else:
461-
logger.warning(f"[{self._name}] Stage-{stage_id} returned no table data")
462465
else:
463466
logger.warning(f"[{self._name}] Stage-{stage_id} returned non-dict data: {type(stage_data)}")
464467
else:

vllm_omni/entrypoints/omni_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def start_profile(self, trace_filename: str | None = None) -> None:
120120
121121
Args:
122122
trace_filename: Optional base filename for trace files.
123-
If None, a timestamp-based name will be generated.
123+
Note: vLLM profiler wrappers ignore this value and write traces
124+
under VLLM_TORCH_PROFILER_DIR instead.
124125
"""
125126
if hasattr(self, "engine") and self.engine:
126127
self.engine.start_profile(trace_filename)

vllm_omni/entrypoints/omni_stage.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -735,11 +735,11 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict:
735735
if task_type == OmniStageTaskType.PROFILER_START:
736736
if stage_type == "diffusion":
737737
try:
738-
profile_dir = _os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles")
739-
_os.makedirs(profile_dir, exist_ok=True)
740-
trace_filename = f"stage_{stage_id}_diffusion_{int(_time.time())}"
741-
stage_engine.start_profile(trace_filename=trace_filename)
742-
logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id)
738+
profile_dir = _os.environ.get("VLLM_TORCH_PROFILER_DIR")
739+
if profile_dir:
740+
_os.makedirs(profile_dir, exist_ok=True)
741+
stage_engine.start_profile()
742+
logger.info("[Stage-%s] Diffusion profiler started", stage_id)
743743
except Exception as e:
744744
logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e)
745745
else:
@@ -753,10 +753,9 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict:
753753
elif task_type == OmniStageTaskType.PROFILER_STOP:
754754
if stage_type == "diffusion":
755755
try:
756-
# CRITICAL: Capture return value
757756
result_data = stage_engine.stop_profile()
758-
logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id)
759-
return result_data
757+
logger.info("[Stage-%s] Diffusion profiler stopped", stage_id)
758+
return result_data if isinstance(result_data, dict) else {}
760759
except Exception as e:
761760
logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e)
762761
return {}
@@ -1289,11 +1288,11 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None:
12891288
if stage_type == "diffusion":
12901289
try:
12911290
# Sync call is safe here — diffusion profiling is lightweight
1292-
profile_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./profiles")
1293-
os.makedirs(profile_dir, exist_ok=True)
1294-
trace_filename = f"stage_{stage_id}_diffusion_{int(time.time())}"
1295-
stage_engine.start_profile(trace_filename=trace_filename)
1296-
logger.info("[Stage-%s] Diffusion Torch profiler started", stage_id)
1291+
profile_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR")
1292+
if profile_dir:
1293+
os.makedirs(profile_dir, exist_ok=True)
1294+
stage_engine.start_profile()
1295+
logger.info("[Stage-%s] Diffusion profiler started", stage_id)
12971296
except Exception as e:
12981297
logger.warning("[Stage-%s] Failed to start diffusion profiler: %s", stage_id, e)
12991298
else:
@@ -1306,10 +1305,10 @@ async def handle_profiler_task_async(task_type: OmniStageTaskType) -> None:
13061305
elif task_type == OmniStageTaskType.PROFILER_STOP:
13071306
if stage_type == "diffusion":
13081307
try:
1309-
trace_files = stage_engine.stop_profile()
1310-
logger.info("[Stage-%s] Diffusion Torch profiler stopped", stage_id)
1311-
if trace_files:
1312-
logger.info("Diffusion trace files: %s", trace_files)
1308+
result_data = stage_engine.stop_profile()
1309+
logger.info("[Stage-%s] Diffusion profiler stopped", stage_id)
1310+
if result_data:
1311+
logger.info("Diffusion profiler result: %s", result_data)
13131312
except Exception as e:
13141313
logger.warning("[Stage-%s] Failed to stop diffusion profiler: %s", stage_id, e)
13151314
else:

0 commit comments

Comments
 (0)