Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 51b16f9

Browse files
author
DEKHTIARJonathan
committed
Reject prefetch on device with INT32 inputs
1 parent f49503c commit 51b16f9

File tree

1 file changed

+20
-23
lines changed

1 file changed

+20
-23
lines changed

tftrt/benchmarking-python/dataloading_utils.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,29 @@
88
from benchmark_autotuner import auto_tf_func_tuner
99

1010

11-
class SyntheticDataset(object):
12-
def __iter__(self):
13-
data = 0
14-
15-
def __init__(self, dataset, device):
16-
dataset = dataset.take(count=1) # loop over 1 batch
17-
dataset = dataset.cache()
18-
dataset = dataset.repeat()
19-
dataset = dataset.prefetch(
20-
buffer_size=tf.data.experimental.AUTOTUNE
21-
)
22-
dataset = dataset.apply(
23-
tf.data.experimental.prefetch_to_device(
24-
device,
25-
buffer_size=tf.data.experimental.AUTOTUNE
26-
)
27-
)
28-
self._ds = dataset
29-
self._data_batch = next(iter(dataset))
30-
31-
def __iter__(self):
32-
return iter(self._ds)
11+
def SyntheticDataset(dataset, device):
12+
dataset = dataset.take(count=1) # loop over 1 batch
13+
dataset = dataset.cache()
14+
dataset = dataset.repeat()
15+
dataset = dataset.prefetch(
16+
buffer_size=tf.data.experimental.AUTOTUNE
17+
)
18+
dataset = ensure_dataset_on_gpu(dataset, device)
19+
return dataset
3320

3421

3522
def ensure_dataset_on_gpu(dataset, device):
36-
if isinstance(dataset, SyntheticDataset):
23+
24+
# ensuring no tensor dtype == int32
25+
input_batch = next(iter(dataset))
26+
if isinstance(input_batch, dict):
27+
input_batch = input_batch.values()
28+
elif not isinstance(input_batch, (tuple, list)):
29+
input_batch = [input_batch]
30+
31+
if any([t.dtype == tf.int32 for t in input_batch]):
32+
print("[WARNING] The dataloader generates INT32 tensors. Prefetch to "
33+
"GPU not supported")
3734
return dataset
3835

3936
try:

0 commit comments

Comments
 (0)