55import os
66
77import abc
8+ import copy
9+ import json
810import logging
911import sys
1012import time
@@ -104,6 +106,31 @@ def _debug_print(self, msg):
104106 if self ._args .debug :
105107 print (f"[DEBUG] { msg } " )
106108
109+ def _export_runtime_metrics_to_json (self , metric_dict ):
110+
111+ metric_dict = {
112+ # Creating a copy to avoid modifying the original
113+ "results" : copy .deepcopy (metric_dict ),
114+ "runtime_arguments" : vars (self ._args )
115+ }
116+
117+ json_path = self ._args .export_metrics_json_path
118+ if json_path is not None :
119+ try :
120+ with open (json_path , 'w' ) as json_f :
121+ json_string = json .dumps (
122+ metric_dict ,
123+ default = lambda o : o .__dict__ ,
124+ sort_keys = True ,
125+ indent = 4
126+ )
127+ print (json_string , file = json_f )
128+ except Exception as e :
129+ print (
130+ "[ERROR] Impossible to save JSON File at path: "
131+ f"{ json_path } .\n Error: { str (e )} "
132+ )
133+
107134 def _get_graph_func (self ):
108135 """Retreives a frozen SavedModel and applies TF-TRT
109136 use_tftrt: bool, if true use TensorRT
@@ -389,16 +416,15 @@ def log_step(step_idx, display_every, iter_time):
389416
390417 with timed_section ("Metric Computation" ):
391418
419+ metrics = dict ()
420+
392421 if not self ._args .use_synthetic_data :
393422 metric , metric_units = self .evaluate_model (
394423 data_aggregator .predicted_dict ,
395424 data_aggregator .expected_dict , bypass_data_to_eval
396425 )
397- print ( f"- { metric_units :35s } : { metric :.2f } " )
426+ metrics [ "Metric" ] = {metric_units : metric }
398427
399- metrics = dict ()
400-
401- if not self ._args .use_synthetic_data :
402428 metrics ["Total Samples Processed" ] = (
403429 data_aggregator .total_samples_processed
404430 )
@@ -419,10 +445,18 @@ def log_step(step_idx, display_every, iter_time):
419445 metrics ['GPU Latency Min (ms)' ] = np .min (run_times ) * 1000
420446 metrics ['GPU Latency Max (ms)' ] = np .max (run_times ) * 1000
421447
422- for key , val in sorted (metrics .items ()):
448+ self ._export_runtime_metrics_to_json (metrics )
449+
450+ def log_value (key , val ):
423451 if isinstance (val , int ):
424452 print (f"- { key :35s} : { val } " )
425453 else :
426454 print (f"- { key :35s} : { val :.2f} " )
427455
456+ for key , val in sorted (metrics .items ()):
457+ if isinstance (val , dict ):
458+ log_value (* list (val .items ())[0 ])
459+ else :
460+ log_value (key , val )
461+
428462 print () # visual spacing
0 commit comments