@@ -192,7 +192,6 @@ def aggregate_rx_and_maybe_total(
192192 rx_bytes : float ,
193193 rx_ms : float ,
194194 in_flight_ms : float ,
195- extra_timings : dict [str , float | None ] | None = None ,
196195) -> tuple [int , float , float ] | None :
197196 try :
198197 # Update RX aggregates for (stage_id-1 -> stage_id)
@@ -209,9 +208,6 @@ def aggregate_rx_and_maybe_total(
209208 "rx_count" : 0.0 ,
210209 "sum_total_ms" : 0.0 ,
211210 "total_count" : 0.0 ,
212- "sum_vae_time_ms" : 0.0 ,
213- "sum_dit_time_ms" : 0.0 ,
214- "sum_denoise_time_ms" : 0.0 ,
215211 }
216212 transfer_agg [key ] = agg
217213 agg ["sum_rx_bytes" ] += float (rx_bytes )
@@ -228,13 +224,6 @@ def aggregate_rx_and_maybe_total(
228224 total_ms = tx_ms + float (in_flight_ms ) + float (rx_ms )
229225 agg ["sum_total_ms" ] += total_ms
230226 agg ["total_count" ] += 1.0
231- # Accumulate diffusion timings if provided
232- if extra_timings :
233- for k in ("vae_time_ms" , "dit_time_ms" , "denoise_time_ms" ):
234- v = extra_timings .get (k )
235- if v is None :
236- continue
237- agg [f"sum_{ k } " ] = agg .get (f"sum_{ k } " , 0.0 ) + float (v )
238227 # accumulate per-request transfer totals
239228 try :
240229 pr = per_request .setdefault (rid_key , {"stages" : {}, "transfers_ms" : 0.0 , "transfers_bytes" : 0 })
@@ -260,7 +249,6 @@ def record_sender_transfer_agg(
260249 try :
261250 key = (from_stage , to_stage )
262251 agg = transfer_agg .get (key )
263-
264252 if agg is None :
265253 agg = {
266254 "sum_bytes" : 0.0 ,
@@ -271,24 +259,17 @@ def record_sender_transfer_agg(
271259 "rx_count" : 0.0 ,
272260 "sum_total_ms" : 0.0 ,
273261 "total_count" : 0.0 ,
274- "sum_vae_time_ms" : 0.0 ,
275- "sum_dit_time_ms" : 0.0 ,
276- "sum_denoise_time_ms" : 0.0 ,
277262 }
278263 transfer_agg [key ] = agg
279-
280- # sender-side aggregation
281264 agg ["sum_bytes" ] += float (size_bytes )
282265 agg ["sum_ms" ] += float (tx_ms )
283266 agg ["count" ] += 1.0
284-
285267 # Store sender-side timing for per-request combination
286268 rid_key = str (req_id )
287269 transfer_edge_req [(from_stage , to_stage , rid_key )] = {
288270 "tx_ms" : float (tx_ms ),
289271 "size_bytes" : float (size_bytes ),
290272 }
291-
292273 except Exception :
293274 pass
294275
@@ -348,11 +329,6 @@ def build_transfer_summary(
348329 sum_total_ms = float (agg .get ("sum_total_ms" , 0.0 ))
349330 samples_total = int (agg .get ("total_count" , 0.0 ))
350331 total_mbps = (sum_bytes * 8.0 ) / (max (sum_total_ms , 1e-6 ) * 1000.0 ) if sum_bytes > 0 else 0.0
351- sum_vae_ms = float (agg .get ("sum_vae_time_ms" , 0.0 ))
352- sum_dit_ms = float (agg .get ("sum_dit_time_ms" , 0.0 ))
353- sum_denoise_ms = float (agg .get ("sum_denoise_time_ms" , 0.0 ))
354- def _avg (total : float ) -> float :
355- return total / max (samples_total , 1 ) if total > 0 else 0.0
356332 summary .append (
357333 {
358334 "from_stage" : src ,
@@ -368,12 +344,6 @@ def _avg(total: float) -> float:
368344 "total_samples" : samples_total ,
369345 "total_transfer_time_ms" : sum_total_ms ,
370346 "total_mbps" : total_mbps ,
371- "diffusion_vae_time_ms" : sum_vae_ms ,
372- "diffusion_vae_time_avg_ms" : _avg (sum_vae_ms ),
373- "diffusion_dit_time_ms" : sum_dit_ms ,
374- "diffusion_dit_time_avg_ms" : _avg (sum_dit_ms ),
375- "diffusion_denoise_time_ms" : sum_denoise_ms ,
376- "diffusion_denoise_time_avg_ms" : _avg (sum_denoise_ms ),
377347 }
378348 )
379349 return summary
@@ -395,11 +365,8 @@ class StageRequestMetrics:
395365 rx_decode_time_ms : float
396366 rx_transfer_bytes : int
397367 rx_in_flight_time_ms : float
398- vae_time_ms : float | None = None
399- dit_time_ms : float | None = None
400- denoise_time_ms : float | None = None
401368
402- stage_stats : StageStats | None = None
369+ stage_stats : StageStats
403370
404371
405372class OrchestratorMetrics :
@@ -476,11 +443,6 @@ def on_stage_metrics(self, stage_id: int, req_id: Any, metrics: dict[str, Any])
476443 rx_b = float (metrics .get ("rx_transfer_bytes" , 0.0 ))
477444 rx_ms = float (metrics .get ("rx_decode_time_ms" , 0.0 ))
478445 in_flight_ms = float (metrics .get ("rx_in_flight_time_ms" , 0.0 ))
479- extra_timings = {
480- "vae_time_ms" : metrics .get ("vae_time_ms" ),
481- "dit_time_ms" : metrics .get ("dit_time_ms" ),
482- "denoise_time_ms" : metrics .get ("denoise_time_ms" ),
483- }
484446 combined = aggregate_rx_and_maybe_total (
485447 self .transfer_edge_req ,
486448 self .transfer_agg ,
@@ -490,7 +452,6 @@ def on_stage_metrics(self, stage_id: int, req_id: Any, metrics: dict[str, Any])
490452 rx_b ,
491453 rx_ms ,
492454 in_flight_ms ,
493- extra_timings ,
494455 )
495456 if self .enable_stats and stage_id > 0 :
496457 log_transfer_rx (
0 commit comments