Skip to content

Commit b950768

Browse files
committed
feat: add profiling for vllm-omni
Signed-off-by: Chen Yang <2082464740@qq.com>
1 parent 98ba3d9 commit b950768

File tree

1 file changed

+1
-40
lines changed

1 file changed

+1
-40
lines changed

vllm_omni/entrypoints/log_utils.py

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def aggregate_rx_and_maybe_total(
192192
rx_bytes: float,
193193
rx_ms: float,
194194
in_flight_ms: float,
195-
extra_timings: dict[str, float | None] | None = None,
196195
) -> tuple[int, float, float] | None:
197196
try:
198197
# Update RX aggregates for (stage_id-1 -> stage_id)
@@ -209,9 +208,6 @@ def aggregate_rx_and_maybe_total(
209208
"rx_count": 0.0,
210209
"sum_total_ms": 0.0,
211210
"total_count": 0.0,
212-
"sum_vae_time_ms": 0.0,
213-
"sum_dit_time_ms": 0.0,
214-
"sum_denoise_time_ms": 0.0,
215211
}
216212
transfer_agg[key] = agg
217213
agg["sum_rx_bytes"] += float(rx_bytes)
@@ -228,13 +224,6 @@ def aggregate_rx_and_maybe_total(
228224
total_ms = tx_ms + float(in_flight_ms) + float(rx_ms)
229225
agg["sum_total_ms"] += total_ms
230226
agg["total_count"] += 1.0
231-
# Accumulate diffusion timings if provided
232-
if extra_timings:
233-
for k in ("vae_time_ms", "dit_time_ms", "denoise_time_ms"):
234-
v = extra_timings.get(k)
235-
if v is None:
236-
continue
237-
agg[f"sum_{k}"] = agg.get(f"sum_{k}", 0.0) + float(v)
238227
# accumulate per-request transfer totals
239228
try:
240229
pr = per_request.setdefault(rid_key, {"stages": {}, "transfers_ms": 0.0, "transfers_bytes": 0})
@@ -260,7 +249,6 @@ def record_sender_transfer_agg(
260249
try:
261250
key = (from_stage, to_stage)
262251
agg = transfer_agg.get(key)
263-
264252
if agg is None:
265253
agg = {
266254
"sum_bytes": 0.0,
@@ -271,24 +259,17 @@ def record_sender_transfer_agg(
271259
"rx_count": 0.0,
272260
"sum_total_ms": 0.0,
273261
"total_count": 0.0,
274-
"sum_vae_time_ms": 0.0,
275-
"sum_dit_time_ms": 0.0,
276-
"sum_denoise_time_ms": 0.0,
277262
}
278263
transfer_agg[key] = agg
279-
280-
# sender-side aggregation
281264
agg["sum_bytes"] += float(size_bytes)
282265
agg["sum_ms"] += float(tx_ms)
283266
agg["count"] += 1.0
284-
285267
# Store sender-side timing for per-request combination
286268
rid_key = str(req_id)
287269
transfer_edge_req[(from_stage, to_stage, rid_key)] = {
288270
"tx_ms": float(tx_ms),
289271
"size_bytes": float(size_bytes),
290272
}
291-
292273
except Exception:
293274
pass
294275

@@ -348,11 +329,6 @@ def build_transfer_summary(
348329
sum_total_ms = float(agg.get("sum_total_ms", 0.0))
349330
samples_total = int(agg.get("total_count", 0.0))
350331
total_mbps = (sum_bytes * 8.0) / (max(sum_total_ms, 1e-6) * 1000.0) if sum_bytes > 0 else 0.0
351-
sum_vae_ms = float(agg.get("sum_vae_time_ms", 0.0))
352-
sum_dit_ms = float(agg.get("sum_dit_time_ms", 0.0))
353-
sum_denoise_ms = float(agg.get("sum_denoise_time_ms", 0.0))
354-
def _avg(total: float) -> float:
355-
return total / max(samples_total, 1) if total > 0 else 0.0
356332
summary.append(
357333
{
358334
"from_stage": src,
@@ -368,12 +344,6 @@ def _avg(total: float) -> float:
368344
"total_samples": samples_total,
369345
"total_transfer_time_ms": sum_total_ms,
370346
"total_mbps": total_mbps,
371-
"diffusion_vae_time_ms": sum_vae_ms,
372-
"diffusion_vae_time_avg_ms": _avg(sum_vae_ms),
373-
"diffusion_dit_time_ms": sum_dit_ms,
374-
"diffusion_dit_time_avg_ms": _avg(sum_dit_ms),
375-
"diffusion_denoise_time_ms": sum_denoise_ms,
376-
"diffusion_denoise_time_avg_ms": _avg(sum_denoise_ms),
377347
}
378348
)
379349
return summary
@@ -395,11 +365,8 @@ class StageRequestMetrics:
395365
rx_decode_time_ms: float
396366
rx_transfer_bytes: int
397367
rx_in_flight_time_ms: float
398-
vae_time_ms: float | None = None
399-
dit_time_ms: float | None = None
400-
denoise_time_ms: float | None = None
401368

402-
stage_stats: StageStats | None = None
369+
stage_stats: StageStats
403370

404371

405372
class OrchestratorMetrics:
@@ -476,11 +443,6 @@ def on_stage_metrics(self, stage_id: int, req_id: Any, metrics: dict[str, Any])
476443
rx_b = float(metrics.get("rx_transfer_bytes", 0.0))
477444
rx_ms = float(metrics.get("rx_decode_time_ms", 0.0))
478445
in_flight_ms = float(metrics.get("rx_in_flight_time_ms", 0.0))
479-
extra_timings = {
480-
"vae_time_ms": metrics.get("vae_time_ms"),
481-
"dit_time_ms": metrics.get("dit_time_ms"),
482-
"denoise_time_ms": metrics.get("denoise_time_ms"),
483-
}
484446
combined = aggregate_rx_and_maybe_total(
485447
self.transfer_edge_req,
486448
self.transfer_agg,
@@ -490,7 +452,6 @@ def on_stage_metrics(self, stage_id: int, req_id: Any, metrics: dict[str, Any])
490452
rx_b,
491453
rx_ms,
492454
in_flight_ms,
493-
extra_timings,
494455
)
495456
if self.enable_stats and stage_id > 0:
496457
log_transfer_rx(

0 commit comments

Comments
 (0)