 import os

 import abc
+import contextlib
 import copy
 import json
 import logging
@@ -389,6 +390,11 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
             dataset, activate=self._args.debug_performance
         )

+        @force_gpu_resync
+        @tf.function()
+        def dequeue_batch(ds_iter):
+            return next(ds_iter)
+
         @force_gpu_resync
         @tf.function()
         def force_data_on_gpu(data, device="/gpu:0"):
@@ -408,53 +414,70 @@ def force_data_on_gpu(data, device="/gpu:0"):
         step_idx = 0
         ds_iter = iter(dataset)

-        while True:
+        if self._args.tf_profile_export_path:
+            profiling_ctx = tf.profiler.experimental.Profile(
+                self._args.tf_profile_export_path
+            )
+            tracing_ctx = tf.profiler.experimental.Trace
+        else:
+            profiling_ctx = contextlib.nullcontext()
+            tracing_ctx = lambda *a, **kw: contextlib.nullcontext()

-            try:
-                start_time = time.time()
-                data_batch = next(ds_iter)
-                dequeue_times.append(time.time() - start_time)
-            except:
-                break
-
-            start_time = time.time()
-            data_batch = force_data_on_gpu(data_batch)
-            memcopy_times.append(time.time() - start_time)
-
-            x, y = self.preprocess_model_inputs(data_batch)
-
-            start_time = time.time()
-            y_pred = infer_batch(x)
-            iter_times.append(time.time() - start_time)
-
-            if not self._args.debug_performance:
-                log_step(
-                    step_idx + 1,
-                    display_every=self._args.display_every,
-                    iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
-                    memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
-                    dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
-                )
-            else:
-                print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
-                print(f"{'Data MemCopyHtoD Time':18s}: {memcpyHtoD_time[-1]:08.4f}s")
-                print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
+        with profiling_ctx:
+
+            while True:

-            if not self._args.use_synthetic_data:
-                data_aggregator.aggregate_data(y_pred, y)
+                step_idx += 1

-            if (self._args.num_iterations is not None and
-                    step_idx + 1 >= self._args.num_iterations):
-                break
+                with tracing_ctx('Inference Step', step_num=step_idx, _r=1):

-            step_idx += 1
+                    with tracing_ctx('Input Dequeueing', step_num=step_idx, _r=1):
+                        try:
+                            start_time = time.time()
+                            data_batch = dequeue_batch(ds_iter)
+                            dequeue_times.append(time.time() - start_time)
+                        except:
+                            break
+
+                    with tracing_ctx('Inputs MemcpyHtoD', step_num=step_idx, _r=1):
+                        start_time = time.time()
+                        data_batch = force_data_on_gpu(data_batch)
+                        memcopy_times.append(time.time() - start_time)
+
+                    with tracing_ctx('Inputs Preprocessing', step_num=step_idx, _r=1):
+                        x, y = self.preprocess_model_inputs(data_batch)
+
+                    with tracing_ctx('GPU Inference', step_num=step_idx, _r=1):
+                        start_time = time.time()
+                        y_pred = infer_batch(x)
+                        iter_times.append(time.time() - start_time)
+
+                    if not self._args.debug_performance:
+                        log_step(
+                            step_idx,
+                            display_every=self._args.display_every,
+                            iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
+                            memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
+                            dequeue_time=np.mean(dequeue_times[-self._args.display_every:]) * 1000
+                        )
+                    else:
+                        print(f"{'GPU Iteration Time':18s}: {iter_times[-1]:08.4f}s")
+                        print(f"{'Data MemCopyHtoD Time':18s}: {memcopy_times[-1]:08.4f}s")
+                        print(f"{'Data Dequeue Time':18s}: {dequeue_times[-1]:08.4f}s")
+
+                    if not self._args.use_synthetic_data:
+                        data_aggregator.aggregate_data(y_pred, y)
+
+                if (self._args.num_iterations is not None and
+                        step_idx >= self._args.num_iterations):
+                    break

         if (
             not self._args.debug_performance and
             step_idx % self._args.display_every != 0
         ):  # avoids double printing
             log_step(
-                step_idx + 1,
+                step_idx,
                 display_every=1,  # force print
                 iter_time=np.mean(iter_times[-self._args.display_every:]) * 1000,
                 memcpyHtoD_time=np.mean(memcopy_times[-self._args.display_every:]) * 1000,
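
For readers unfamiliar with the pattern introduced above: profiling is made strictly optional by substituting `contextlib.nullcontext` objects when `--tf_profile_export_path` is not set, so the hot loop keeps the same `with` structure whether or not the profiler is active. A standalone sketch of the idea follows; the `make_profiling_contexts` helper, its `export_path` argument, and the dummy loop are illustrative only and not part of this patch.

```python
import contextlib

import tensorflow as tf


def make_profiling_contexts(export_path=None):
    """Return (profiling_ctx, tracing_ctx); both are no-ops when export_path is falsy."""
    if export_path:
        # Profile(...) starts/stops the profiler around the whole benchmark run,
        # Trace(...) annotates individual steps so they appear in TensorBoard.
        profiling_ctx = tf.profiler.experimental.Profile(export_path)
        tracing_ctx = tf.profiler.experimental.Trace
    else:
        # Same call shapes, zero overhead: every `with` block becomes a no-op.
        profiling_ctx = contextlib.nullcontext()
        tracing_ctx = lambda *args, **kwargs: contextlib.nullcontext()
    return profiling_ctx, tracing_ctx


profiling_ctx, tracing_ctx = make_profiling_contexts(export_path=None)

with profiling_ctx:
    for step in range(1, 4):
        with tracing_ctx("Inference Step", step_num=step, _r=1):
            pass  # one benchmark iteration would run here
```

With `export_path=None` the snippet runs without starting the profiler; pointing it at a directory produces a trace viewable in TensorBoard's Profile tab.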
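The new `dequeue_batch` helper reuses the `force_gpu_resync` decorator already applied to `force_data_on_gpu`. That decorator is defined elsewhere in this file and is not shown in the diff; the sketch below is only an assumption about the usual shape of such a decorator, which forces a small device-to-host copy after each call so the `time.time()` deltas recorded above cover the full GPU execution rather than just the asynchronous kernel launch.

```python
import tensorflow as tf


def force_gpu_resync(func):
    """Sketch only: block until queued GPU work has completed before returning."""
    probe = tf.constant(0.0)  # tiny tensor reused on every call

    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        # .numpy() forces a host copy, which cannot complete before the GPU
        # stream has drained, so wall-clock timing around the call is meaningful.
        (probe + 1.0).numpy()
        return result

    return wrapper
```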