Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 903e059

Browse files
author
DEKHTIARJonathan
committed
[Benchmark-Py] Release 1.0.1 - Remove autotuning on get_dequeue_batch_fn in order to fix DALIDataset patch
1 parent f5074e2 commit 903e059

File tree

6 files changed

+47
-40
lines changed

6 files changed

+47
-40
lines changed

tftrt/benchmarking-python/CHANGELOG.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,23 @@ Description of the change
3737
existing metrics or models scripts in a way that would make
3838
metrics not comparable between minor releases.
3939

40-
- **Patch Version:** Changes that are expected to have no change to the operation
41-
of the benchmark nor the way metrics are calculated.
42-
Basically these changes are transparent for the user.
40+
- **Patch Version:** Changes that are expected to have no change to the
41+
operation of the benchmark nor the way metrics are
42+
calculated. Basically these changes are transparent for the
43+
user.
4344

4445
# Versions
4546

4647
<!-- YOU CAN EDIT FROM HERE -->
4748

49+
## [1.0.1] - 2022.07.25 - @DEKHTIARJonathan
50+
51+
Removing AutoTuning on `get_dequeue_batch_fn` because DALIDataset was not
52+
respecting the limit on the number of batches.
53+
54+
It should not impact the benchmark results, most of the time, the autotuner was
55+
selecting the eager version anyway.
56+
4857
## [1.0.0] - 2022.07.20 - @DEKHTIARJonathan
4958

5059
Initial Versioning Release.

tftrt/benchmarking-python/benchmark_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
# The `__version__` number shall be updated everytime core benchmarking files
4141
# are updated.
4242
# Please update CHANGELOG.md with a description of what this version changed.
43-
__version__ = "1.0.0"
43+
__version__ = "1.0.1"
4444

4545
__all__ = ["__version__", "BaseBenchmarkRunner"]
4646

tftrt/benchmarking-python/benchmark_utils.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -204,36 +204,3 @@ def aggregate_data(self, y_pred, y):
204204
f"Expected: {y[key].shape}"
205205
)
206206
self._expected[key][idx_start:idx_stop] = y[key]
207-
208-
209-
def patch_dali_dataset(dataset):
210-
import nvidia.dali.plugin.tf as dali_tf
211-
212-
if not isinstance(dataset, dali_tf.DALIDataset):
213-
raise TypeError(
214-
"Dataset supplied should be an instance of `DALIDataset`."
215-
f"Received: `{type(dataset)}`"
216-
)
217-
218-
def take(self, limit):
219-
220-
class _Dataset(self.__class__):
221-
222-
def __init__(self, _ds, _limit):
223-
self._ds = _ds
224-
self._limit = _limit
225-
226-
def __iter__(self):
227-
idx = 0
228-
for data in self._ds:
229-
if idx >= self._limit:
230-
break
231-
yield data
232-
idx += 1
233-
234-
return _Dataset(self, limit)
235-
236-
# Monkey Patch
237-
dataset.__class__.take = take
238-
239-
return dataset

tftrt/benchmarking-python/dataloading_utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import tensorflow as tf
88

99
from benchmark_autotuner import auto_tf_func_tuner
10+
from benchmark_utils import force_gpu_resync
1011

1112

1213
def SyntheticDataset(dataset, device):
@@ -69,7 +70,7 @@ def ensure_dataset_on_gpu(dataset, device):
6970

7071
def get_dequeue_batch_fn(ds_iter, use_xla=False, use_synthetic_data=False):
7172

72-
@auto_tf_func_tuner(use_xla=use_xla, use_synthetic_data=use_synthetic_data)
73+
@force_gpu_resync
7374
def dequeue_batch_fn():
7475
"""This function should not use tf.function().
7576
It would create two unwanted effects:
@@ -98,3 +99,33 @@ def force_data_on_gpu_fn(data):
9899
return tf.identity(data)
99100

100101
return force_data_on_gpu_fn
102+
103+
104+
def patch_dali_dataset(dataset):
105+
import nvidia.dali.plugin.tf as dali_tf
106+
107+
if not isinstance(dataset, dali_tf.DALIDataset):
108+
raise TypeError(
109+
"Dataset supplied should be an instance of `DALIDataset`."
110+
f"Received: `{type(dataset)}`"
111+
)
112+
113+
def take(self, limit):
114+
115+
class _Dataset(self.__class__):
116+
117+
def __init__(self, _ds, _limit):
118+
self._ds = _ds
119+
self._limit = _limit
120+
121+
def __iter__(self):
122+
ds_iter = iter(self._ds)
123+
for idx in tf.range(self._limit):
124+
yield next(ds_iter)
125+
126+
return _Dataset(self, limit)
127+
128+
# Monkey Patch
129+
dataset.__class__.take = take
130+
131+
return dataset

tftrt/benchmarking-python/nvidia_examples/nnunet2d_tf2/infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from benchmark_args import BaseCommandLineAPI
4444
from benchmark_runner import BaseBenchmarkRunner
45-
from benchmark_utils import patch_dali_dataset
45+
from dataloading_utils import patch_dali_dataset
4646

4747

4848
class CommandLineAPI(BaseCommandLineAPI):

tftrt/benchmarking-python/nvidia_examples/nnunet3d_tf2/infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from benchmark_args import BaseCommandLineAPI
4444
from benchmark_runner import BaseBenchmarkRunner
45-
from benchmark_utils import patch_dali_dataset
45+
from dataloading_utils import patch_dali_dataset
4646

4747

4848
class CommandLineAPI(BaseCommandLineAPI):

0 commit comments

Comments
 (0)