 from benchmark_autotuner import auto_tf_func_tuner


-class SyntheticDataset(object):
-    def __iter__(self):
-        data = 0
-
-    def __init__(self, dataset, device):
-        dataset = dataset.take(count=1)  # loop over 1 batch
-        dataset = dataset.cache()
-        dataset = dataset.repeat()
-        dataset = dataset.prefetch(
-            buffer_size=tf.data.experimental.AUTOTUNE
-        )
-        dataset = dataset.apply(
-            tf.data.experimental.prefetch_to_device(
-                device,
-                buffer_size=tf.data.experimental.AUTOTUNE
-            )
-        )
-        self._ds = dataset
-        self._data_batch = next(iter(dataset))
-
-    def __iter__(self):
-        return iter(self._ds)
+def SyntheticDataset(dataset, device):
+    dataset = dataset.take(count=1)  # loop over 1 batch
+    dataset = dataset.cache()
+    dataset = dataset.repeat()
+    dataset = dataset.prefetch(
+        buffer_size=tf.data.experimental.AUTOTUNE
+    )
+    dataset = ensure_dataset_on_gpu(dataset, device)
+    return dataset
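For reference, the refactored SyntheticDataset is now a plain function that turns an arbitrary tf.data pipeline into a synthetic loader: it keeps a single batch, caches it, repeats it indefinitely, prefetches it, and hands it to ensure_dataset_on_gpu, so the benchmark loop measures model time rather than input-pipeline time. Below is a minimal, self-contained sketch of the same take/cache/repeat/prefetch recipe; the toy pipeline and variable names are illustrative, not taken from the benchmark code.

```python
import tensorflow as tf

# Toy input pipeline standing in for the real dataloader (illustrative only).
real_ds = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([64, 32, 32, 3])
).batch(16)

# Same recipe as the refactored SyntheticDataset(): keep a single batch,
# cache it, repeat it forever, and prefetch it.
synthetic_ds = (
    real_ds.take(1)
           .cache()
           .repeat()
           .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

for step, batch in enumerate(synthetic_ds.take(5)):
    print(step, batch.shape)  # every step replays the same cached batch
```

Note that cache() precedes repeat(), so the single batch is materialized once and then replayed from memory on every iteration.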


 def ensure_dataset_on_gpu(dataset, device):
-    if isinstance(dataset, SyntheticDataset):
+
+    # ensuring no tensor dtype == int32
+    input_batch = next(iter(dataset))
+    if isinstance(input_batch, dict):
+        input_batch = input_batch.values()
+    elif not isinstance(input_batch, (tuple, list)):
+        input_batch = [input_batch]
+
+    if any([t.dtype == tf.int32 for t in input_batch]):
+        print("[WARNING] The dataloader generates INT32 tensors. Prefetch to "
+              "GPU not supported")
         return dataset

     try:
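For context, prefetch_to_device copies dataset elements into device memory, and TensorFlow typically keeps int32 tensors on the host even when running on a GPU, which is presumably why the diff peeks at one batch and skips GPU prefetch whenever an INT32 tensor is present. Below is a self-contained sketch of that guard; the helper name maybe_prefetch_to_gpu is hypothetical and only mirrors the logic shown above.

```python
import tensorflow as tf

def maybe_prefetch_to_gpu(dataset, device="/gpu:0"):
    """Apply prefetch_to_device only if the dataset yields no INT32 tensors."""
    # Peek at one element to inspect the dtypes of the batch.
    batch = next(iter(dataset))
    if isinstance(batch, dict):
        tensors = list(batch.values())
    elif isinstance(batch, (tuple, list)):
        tensors = list(batch)
    else:
        tensors = [batch]

    if any(t.dtype == tf.int32 for t in tensors):
        print("[WARNING] The dataloader generates INT32 tensors. "
              "Prefetch to GPU not supported")
        return dataset

    # prefetch_to_device should be the final transformation in the pipeline.
    return dataset.apply(
        tf.data.experimental.prefetch_to_device(
            device, buffer_size=tf.data.experimental.AUTOTUNE
        )
    )
```

The peek via next(iter(dataset)) consumes one element, which is harmless for a cached and repeated synthetic dataset; for nested element structures, tf.nest.flatten(batch) would be a more general way to collect the leaf tensors before checking dtypes.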