This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 73b6db5
Author: DEKHTIARJonathan
Parent: fa1e35a

Trimmed Mean Added to make throughput numbers more stable

File tree

2 files changed: +21 −5 lines


tftrt/examples/benchmark_args.py

Lines changed: 12 additions & 1 deletion
@@ -120,10 +120,21 @@ def __init__(self):
         self._parser.add_argument(
             "--num_warmup_iterations",
             type=int,
-            default=100,
+            default=200,
             help="Number of initial iterations skipped from timing."
         )
 
+        self._parser.add_argument(
+            "--trim_mean_percentage",
+            type=float,
+            default=0.1,
+            required=False,
+            help="Percentage used to trim step timing distribution from both "
+            "tails (fastest and slowest steps). 0.1 (default value) means that "
+            "10% of the fastest and slowest iterations will be removed for "
+            "model throughput computation."
+        )
+
         self._parser.add_argument(
             "--total_max_samples",
             type=int,
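
For reference, scipy.stats.trim_mean(values, proportiontocut) drops that fraction of values from each tail before averaging, so --trim_mean_percentage=0.1 discards the 10% fastest and 10% slowest steps. A minimal sketch with made-up numbers (not part of this commit):

import scipy.stats

# Ten values with one large outlier at the top.
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 100]

# Plain mean is pulled up by the outlier: 14.5
print(sum(values) / len(values))

# trim_mean(values, 0.1) cuts int(10 * 0.1) = 1 value from each end
# and averages the remaining [2, ..., 9]: 5.5
print(scipy.stats.trim_mean(values, 0.1))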

tftrt/examples/benchmark_runner.py

Lines changed: 9 additions & 4 deletions
@@ -25,6 +25,8 @@
 from dataloading_utils import get_force_data_on_gpu_fn
 
 import numpy as np
+import scipy as sp
+import scipy.stats
 import tensorflow as tf
 
 from tensorflow.python.compiler.tensorrt import trt_convert as trt
@@ -500,11 +502,14 @@ def log_step(step_idx, display_every, iter_time, memcpyHtoD_time, dequeue_time):
 
         metrics['Total GPU Time (s)'] = int(np.ceil(np.sum(iter_times)))
         metrics['Throughput (samples/sec)'] = (
-            self._args.batch_size / np.mean(iter_times)
-        )
+            self._args.batch_size / sp.stats.trim_mean(
+                iter_times, self._args.trim_mean_percentage))
 
         def timing_metrics(time_arr, log_prefix):
             data = dict()
+            data[f"{log_prefix} Trim Mean [{self._args.trim_mean_percentage * 100}%] (ms)"] = (
+                sp.stats.trim_mean(time_arr, self._args.trim_mean_percentage) * 1000
+            )
             data[f"{log_prefix} 99th_percentile (ms)"] = np.percentile(
                 time_arr, q=99, interpolation='lower'
             ) * 1000
@@ -522,9 +527,9 @@ def timing_metrics(time_arr, log_prefix):
 
         def log_value(key, val):
             if isinstance(val, int):
-                print(f"- {key:45s}: {val}")
+                print(f"- {key:50s}: {val}")
             else:
-                print(f"- {key:45s}: {val:.2f}")
+                print(f"- {key:50s}: {val:.2f}")
 
         for key, val in sorted(metrics.items()):
             if isinstance(val, dict):
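
To see why the trimmed mean stabilizes the reported number, here is a small sketch (hypothetical batch size and step times, not taken from the runner) comparing the previous np.mean-based throughput with the new trimmed-mean one when a single slow step occurs:

import numpy as np
import scipy.stats

batch_size = 32

# Hypothetical per-step latencies in seconds: mostly ~10 ms, with one
# unusually fast step and one 50 ms outlier (e.g. a one-off stall).
iter_times = np.array([0.010] * 18 + [0.009, 0.050])

# Old computation: the single outlier drags throughput down noticeably.
print(batch_size / np.mean(iter_times))                     # ~2678 samples/sec

# New computation: trimming 10% from each tail removes the outlier,
# so the result tracks the typical step time (~10 ms).
print(batch_size / scipy.stats.trim_mean(iter_times, 0.1))  # 3200 samples/sec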
