99import logging
1010import time
1111
12+ from collections import defaultdict
1213from contextlib import contextmanager
1314from functools import partial
15+ from operator import itemgetter
1416
1517import numpy as np
1618import tensorflow as tf
@@ -64,11 +66,11 @@ def before_benchmark(self, **kwargs):
6466 pass
6567
6668 @abc .abstractmethod
67- def compute_accuracy_metric (self , batch_size , steps_executed , ** kwargs ):
69+ def compute_accuracy_metric (self , predictions , expected , ** kwargs ):
6870 raise NotImplementedError ()
6971
7072 @abc .abstractmethod
71- def process_model_output (self , outputs , batch_y , ** kwargs ):
73+ def process_model_output (self , outputs , ** kwargs ):
7274 raise NotImplementedError ()
7375
7476 ############################################################################
@@ -81,21 +83,26 @@ def __init__(
8183 output_saved_model_dir ,
8284 allow_build_at_runtime = False ,
8385 calibration_input_fn = None ,
86+ debug = False ,
8487 gpu_mem_cap = None ,
8588 input_signature_key = DEFAULT_SERVING_SIGNATURE_DEF_KEY ,
8689 max_workspace_size_bytes = DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES ,
8790 minimum_segment_size = 5 ,
8891 num_calib_inputs = None ,
8992 optimize_offline = False ,
9093 optimize_offline_input_fn = None ,
94+ output_tensor_indices = None ,
95+ output_tensor_names = None ,
9196 precision_mode = None ,
9297 use_dynamic_shape = False ,
93- use_tftrt = False
98+ use_tftrt = False ,
9499 ):
95100
96101 logging .getLogger ("tensorflow" ).setLevel (logging .INFO )
97102 logging .disable (logging .WARNING )
98103
104+ self ._debug = debug
105+
99106 # TensorFlow can execute operations synchronously or asynchronously.
100107 # If asynchronous execution is enabled, operations may return
101108 # "non-ready" handles.
@@ -131,15 +138,17 @@ def __init__(
131138 use_tftrt = use_tftrt
132139 )
133140
141+ self ._set_output_tensor_name (output_tensor_indices , output_tensor_names )
142+
134143 def _config_gpu_memory (self , gpu_mem_cap ):
135144 gpus = tf .config .experimental .list_physical_devices ('GPU' )
136145
137146 if not gpus :
138147 raise RuntimeError ("No GPUs has been found." )
139148
140- print ('Found the following GPUs:' )
149+ self . debug_print ('Found the following GPUs:' )
141150 for gpu in gpus :
142- print ( ' ' , gpu )
151+ self . debug_print ( f" \t - { gpu } " )
143152
144153 for gpu in gpus :
145154 try :
@@ -153,6 +162,42 @@ def _config_gpu_memory(self, gpu_mem_cap):
153162 except RuntimeError as e :
154163 print ('Can not set GPU memory config' , e )
155164
165+ def _set_output_tensor_name (
166+ self , output_tensor_indices , output_tensor_names
167+ ):
168+ structured_outputs = self ._graph_func .structured_outputs
169+
170+ if isinstance (structured_outputs , (list , tuple )):
171+ if output_tensor_indices is None :
172+ output_tensor_indices = list (range (len (structured_outputs )))
173+ else :
174+ output_tensor_indices = [
175+ int (i ) for i in output_tensor_indices .split ("," )
176+ ]
177+
178+ self ._output_tensors = output_tensor_indices
179+
180+ elif isinstance (structured_outputs , dict ):
181+ structured_outputs = dict (sorted (structured_outputs .items ()))
182+ if output_tensor_names is None :
183+ output_tensor_names = list (structured_outputs .keys ())
184+ else :
185+ output_tensor_names = [n for n in output_tensor_names .split ("," )]
186+ for name in output_tensor_names :
187+ if name not in structured_outputs .keys ():
188+ raise ValueError (
189+ f"Unknown output_tensor_names received: { name } . " \
190+ f"Authorized: { structured_outputs .keys ()} " )
191+
192+ self ._output_tensors = output_tensor_names
193+
194+ else :
195+ raise RuntimeError ('Unknown structured_outputs format received:' ,
196+ type (structured_outputs ))
197+
198+ self .debug_print (f"Available Output Tensors: { structured_outputs } " )
199+ self .debug_print (f"Chosen Output Tensor: { self ._output_tensors } " )
200+
156201 def _get_graph_func (
157202 self ,
158203 input_saved_model_dir ,
@@ -288,6 +333,10 @@ def _check_input_fn(func, name):
288333
289334 return graph_func
290335
336+ def debug_print (self , msg ):
337+ if self ._debug :
338+ print (f"[DEBUG] { msg } " )
339+
291340 def execute_benchmark (
292341 self ,
293342 batch_size ,
@@ -317,7 +366,22 @@ def execute_benchmark(
317366 @_force_gpu_resync
318367 @tf .function (jit_compile = use_xla )
319368 def infer_step (_batch_x ):
320- return self ._graph_func (_batch_x )
369+ output = self ._graph_func (_batch_x )
370+ return itemgetter (* self ._output_tensors )(output )
371+
372+ predicted_dict = defaultdict (lambda : [])
373+ expected_arr = []
374+
375+ def get_debug_output_shape_str (output ):
376+ if isinstance (output , (tuple , list )):
377+ return [t .shape for t in output ]
378+
379+ elif isinstance (output , dict ):
380+ return {k : v .shape for k , v in output .items ()}
381+
382+ else :
383+ return output .shape
384+
321385
322386 print ("\n Start inference ..." )
323387 for i , data_batch in enumerate (dataset ):
@@ -348,19 +412,62 @@ def infer_step(_batch_x):
348412 ))
349413
350414 if not skip_accuracy_testing :
351- self .process_model_output (
352- outputs = batch_preds ,
353- batch_y = batch_y ,
354- ** kwargs
355- )
415+ if i == 0 :
416+ self .debug_print ("=========== BEFORE PROCESSING ==========" )
417+ debug_batch_preds = get_debug_output_shape_str (batch_preds )
418+ self .debug_print (f"`batch_preds`: { debug_batch_preds } " )
419+ if batch_y is not None :
420+ self .debug_print (f"`batch_y` shape: { batch_y .shape } " )
421+
422+ batch_preds = self .process_model_output (batch_preds , ** kwargs )
423+
424+ if not isinstance (batch_preds , dict ):
425+ raise ValueError (
426+ f"`self.process_model_output` did not return a dict. " \
427+ f"Received: { type (batch_preds )} "
428+ )
429+
430+ if batch_y is not None :
431+ batch_y = batch_y .numpy ()
432+ if batch_y .shape [- 1 ] == 1 :
433+ batch_y = np .squeeze (batch_y , axis = - 1 )
434+
435+ if i == 0 :
436+ self .debug_print ("=========== AFTER PROCESSING ===========" )
437+ debug_batch_preds = get_debug_output_shape_str (batch_preds )
438+ self .debug_print (f"`batch_preds`: { debug_batch_preds } " )
439+ if batch_y is not None :
440+ self .debug_print (f"`batch_y` shape: { batch_y .shape } " )
441+ self .debug_print ("========================================" )
442+
443+ for key , value in batch_preds .items ():
444+ predicted_dict [key ].append (value )
445+
446+ if batch_y is not None :
447+ expected_arr .append (batch_y )
356448
357449 if (i + 1 ) >= num_iterations :
358450 break
359451
360452 if not skip_accuracy_testing :
453+ predicted_dict = {
454+ k : np .concatenate (v , axis = 0 )
455+ for k , v in predicted_dict .items ()
456+ }
457+ if expected_arr :
458+ expected_arr = np .concatenate (expected_arr , axis = 0 )
459+ else :
460+ expected_arr = np .array (expected_arr )
461+
462+ self .debug_print ("=========== BEFORE METRIC COMPUTATION ==========" )
463+ debug_predicted_dict = get_debug_output_shape_str (predicted_dict )
464+ self .debug_print (f"`predicted_dict`: { debug_predicted_dict } " )
465+ self .debug_print (f"`expected_arr` shape: { expected_arr .shape } " )
466+ self .debug_print ("========================================" )
467+
361468 results ['accuracy_metric' ] = self .compute_accuracy_metric (
362- batch_size = batch_size ,
363- steps_executed = steps_executed ,
469+ predictions = predicted_dict ,
470+ expected = expected_arr ,
364471 ** kwargs
365472 )
366473
0 commit comments