3232from benchmark_logger import logging
3333
3434from benchmark_utils import DataAggregator
35+ from benchmark_utils import generate_json_metrics
3536from benchmark_utils import print_dict
3637from benchmark_utils import timed_section
3738
@@ -140,19 +141,12 @@ def _export_runtime_metrics_to_json(self, metric_dict):
140141 if file_path is None :
141142 return
142143
143- metric_dict = {
144- # Creating a copy to avoid modifying the original
145- "results" : copy .deepcopy (metric_dict ),
146- "runtime_arguments" : vars (self ._args )
147- }
144+ json_string = generate_json_metrics (
145+ metrics = metric_dict ,
146+ args = vars (self ._args ),
147+ )
148148
149149 with open (file_path , 'w' ) as json_f :
150- json_string = json .dumps (
151- metric_dict ,
152- default = lambda o : o .__dict__ ,
153- sort_keys = True ,
154- indent = 4
155- )
156150 print (json_string , file = json_f )
157151
158152 except Exception as e :
@@ -205,6 +199,34 @@ def _export_runtime_metrics_to_csv(self, metric_dict):
205199 except Exception as e :
206200 logging .error (f"An exception occured during export to CSV: { e } " )
207201
202+ def _upload_metrics_to_endpoint (self , metric_dict ):
203+
204+ try :
205+
206+ if self ._args .upload_metrics_endpoint is None :
207+ return
208+
209+ json_string = generate_json_metrics (
210+ metrics = metric_dict ,
211+ args = vars (self ._args ),
212+ )
213+
214+ headers = {"Content-Type" : "application/json" }
215+
216+ response = requests .put (
217+ endpoint , data = json .dumps (data ), headers = headers
218+ )
219+ response .raise_for_status ()
220+
221+ logging .info (
222+ "Metrics Uploaded to endpoint: "
223+ f"`{ self ._args .upload_metrics_endpoint } ` with experiment name: "
224+ f"`{ self ._args .experiment_name } `."
225+ )
226+
227+ except Exception as e :
228+ logging .error (f"An exception occured during export to JSON: { e } " )
229+
208230 def _get_graph_func (self ):
209231 """Retreives a frozen SavedModel and applies TF-TRT
210232 use_tftrt: bool, if true use TensorRT
@@ -587,9 +609,12 @@ def start_profiling():
587609 if not self ._args .use_synthetic_data :
588610 data_aggregator .aggregate_data (y_pred , y )
589611
590- if (not self ._args .debug_performance and
591- step_idx % self ._args .display_every !=
592- 0 ): # avoids double printing
612+ # yapf: disable
613+ if (
614+ not self ._args .debug_performance and
615+ # avoids double printing
616+ step_idx % self ._args .display_every != 0
617+ ):
593618 log_step (
594619 step_idx ,
595620 display_every = 1 , # force print
@@ -602,6 +627,7 @@ def start_profiling():
602627 dequeue_times [- self ._args .display_every :]
603628 ) * 1000
604629 )
630+ # yapf: enable
605631
606632 if step_idx >= 100 :
607633 stop_profiling ()
@@ -668,6 +694,7 @@ def timing_metrics(time_arr, log_prefix):
668694
669695 self ._export_runtime_metrics_to_json (metrics )
670696 self ._export_runtime_metrics_to_csv (metrics )
697+ self ._upload_metrics_to_endpoint (metrics )
671698
672699 def log_value (key , val ):
673700 if isinstance (val , (int , str )):
0 commit comments