
Commit b9c1332
Adding JSON export
1 parent c624e14

2 files changed: 47 additions, 5 deletions

tftrt/examples/benchmark_args.py (8 additions, 0 deletions)

@@ -237,6 +237,14 @@ def __init__(self):
 
         # =========================== DEBUG Flags ========================== #
 
+        self._parser.add_argument(
+            "--export_metrics_json_path",
+            type=str,
+            default=None,
+            help="If set, the script will export runtime metrics and arguments "
+            "to the set location in JSON format for further processing."
+        )
+
         self._add_bool_argument(
             name="debug",
             default=False,
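The new flag only registers the argument; the export itself happens in benchmark_runner.py below. As a rough sketch, a downstream consumer of the exported file could look like this (the output path is hypothetical; the top-level keys follow `_export_runtime_metrics_to_json`):

    import json

    # Assuming the benchmark was launched with
    #   --export_metrics_json_path=/tmp/metrics.json  (hypothetical path)
    with open("/tmp/metrics.json") as json_f:
        report = json.load(json_f)

    # "results" holds the runtime metrics dict; "runtime_arguments" holds
    # vars() of the parsed CLI namespace, as written by the exporter.
    for key, val in sorted(report["results"].items()):
        print(f"{key}: {val}")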

tftrt/examples/benchmark_runner.py (39 additions, 5 deletions)

@@ -5,6 +5,8 @@
 import os
 
 import abc
+import copy
+import json
 import logging
 import sys
 import time
@@ -104,6 +106,31 @@ def _debug_print(self, msg):
         if self._args.debug:
             print(f"[DEBUG] {msg}")
 
+    def _export_runtime_metrics_to_json(self, metric_dict):
+
+        metric_dict = {
+            # Creating a copy to avoid modifying the original
+            "results": copy.deepcopy(metric_dict),
+            "runtime_arguments": vars(self._args)
+        }
+
+        json_path = self._args.export_metrics_json_path
+        if json_path is not None:
+            try:
+                with open(json_path, 'w') as json_f:
+                    json_string = json.dumps(
+                        metric_dict,
+                        default=lambda o: o.__dict__,
+                        sort_keys=True,
+                        indent=4
+                    )
+                    print(json_string, file=json_f)
+            except Exception as e:
+                print(
+                    "[ERROR] Failed to save the JSON file at path: "
+                    f"{json_path}.\nError: {str(e)}"
+                )
+
     def _get_graph_func(self):
         """Retrieves a frozen SavedModel and applies TF-TRT
         use_tftrt: bool, if true use TensorRT
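One detail worth noting: `default=lambda o: o.__dict__` tells `json.dumps` how to handle values that are not JSON-serializable (which can occur inside `vars(self._args)`) by falling back to the object's attribute dict. A minimal sketch, using a hypothetical class:

    import json

    class Precision:
        # Hypothetical stand-in for a non-serializable argument value.
        def __init__(self):
            self.name = "FP16"

    # Without the default hook, json.dumps raises TypeError here;
    # with it, the object's __dict__ is serialized instead.
    print(json.dumps({"precision": Precision()}, default=lambda o: o.__dict__))
    # -> {"precision": {"name": "FP16"}}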
@@ -389,16 +416,15 @@ def log_step(step_idx, display_every, iter_time):
 
         with timed_section("Metric Computation"):
 
+            metrics = dict()
+
             if not self._args.use_synthetic_data:
                 metric, metric_units = self.evaluate_model(
                     data_aggregator.predicted_dict,
                     data_aggregator.expected_dict, bypass_data_to_eval
                 )
-                print(f"- {metric_units:35s}: {metric:.2f}")
+                metrics["Metric"] = {metric_units: metric}
 
-            metrics = dict()
-
-            if not self._args.use_synthetic_data:
                 metrics["Total Samples Processed"] = (
                     data_aggregator.total_samples_processed
                 )
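With this restructuring, the accuracy result is no longer printed immediately; it is stored under a nested key so its units string survives the JSON export. The resulting dict has roughly this shape (names and values are hypothetical):

    metrics = {
        "Metric": {"Top-1 Accuracy %": 76.45},  # nested: units kept as the key
        "Total Samples Processed": 50000,
    }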
@@ -419,10 +445,18 @@ def log_step(step_idx, display_every, iter_time):
             metrics['GPU Latency Min (ms)'] = np.min(run_times) * 1000
             metrics['GPU Latency Max (ms)'] = np.max(run_times) * 1000
 
-            for key, val in sorted(metrics.items()):
+            self._export_runtime_metrics_to_json(metrics)
+
+            def log_value(key, val):
                 if isinstance(val, int):
                     print(f"- {key:35s}: {val}")
                 else:
                     print(f"- {key:35s}: {val:.2f}")
 
+            for key, val in sorted(metrics.items()):
+                if isinstance(val, dict):
+                    log_value(*list(val.items())[0])
+                else:
+                    log_value(key, val)
+
             print() # visual spacing
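The final loop unwraps the nested "Metric" entry so the console output keeps the existing `- <name>: <value>` format. A self-contained sketch of the same logic, with hypothetical metric names and values:

    def log_value(key, val):
        if isinstance(val, int):
            print(f"- {key:35s}: {val}")
        else:
            print(f"- {key:35s}: {val:.2f}")

    metrics = {
        "Metric": {"Top-1 Accuracy %": 76.45},
        "GPU Latency Mean (ms)": 8.31,
        "Total Samples Processed": 50000,
    }

    for key, val in sorted(metrics.items()):
        if isinstance(val, dict):
            # Unwrap the single (units, value) pair before logging.
            log_value(*list(val.items())[0])
        else:
            log_value(key, val)

    # Output:
    # - GPU Latency Mean (ms)              : 8.31
    # - Top-1 Accuracy %                   : 76.45
    # - Total Samples Processed            : 50000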
