Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 38e0681

Browse files
author
DEKHTIARJonathan
committed
Force Input DS on Device
1 parent 6451e02 commit 38e0681

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

tftrt/examples/benchmark_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from benchmark_utils import timed_section
2121

2222
from dataloading_utils import SyntheticDataset
23+
from dataloading_utils import ensure_dataset_on_gpu
2324
from dataloading_utils import get_dequeue_batch_fn
2425
from dataloading_utils import get_force_data_on_gpu_fn
2526

@@ -352,6 +353,8 @@ def execute_benchmark(self):
352353
f"synthetic dataset. Performance numbers will be "
353354
f"impacted.\nError: {str(e)}."
354355
)
356+
else:
357+
dataset = ensure_dataset_on_gpu(dataset, device="GPU:0")
355358

356359
@force_gpu_resync
357360
@tf.function(jit_compile=self._args.use_xla)

tftrt/examples/dataloading_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def __iter__(self):
3636
yield data_batch
3737

3838

39+
def ensure_dataset_on_gpu(dataset, device):
    """Return `dataset` placed on `device`, prefetching to it when needed.

    The dataset's current placement is read from the device string of its
    variant tensor (`dataset._variant_tensor_attr.device`) and compared
    case-insensitively against `device`.

    Args:
        dataset: Object exposing `_variant_tensor_attr.device` and `apply()`;
            presumably a `tf.data.Dataset` -- confirm against callers.
        device: Device name to look for and prefetch to, e.g. "GPU:0".

    Returns:
        `dataset` itself when its device string already names `device`;
        otherwise the result of applying
        `tf.data.experimental.prefetch_to_device` for `device` to it.
    """
    # Function-scope import so this block does not depend on the module's
    # top-of-file import list.
    import re

    requested = device.lower()
    current = dataset._variant_tensor_attr.device.lower()

    # A plain substring test would treat "gpu:1" as present in
    # ".../device:gpu:10" and leave the dataset on the wrong device. When the
    # requested name ends with a digit, refuse matches that are immediately
    # followed by another digit; every other match behaves as before.
    pattern = re.escape(requested)
    if requested[-1:].isdigit():
        pattern += r"(?![0-9])"

    if re.search(pattern, current):
        return dataset

    return dataset.apply(
        tf.data.experimental.prefetch_to_device(
            device=device,
            buffer_size=tf.data.experimental.AUTOTUNE,
        )
    )
49+
50+
3951
def get_dequeue_batch_fn(ds_iter):
4052

4153
@force_gpu_resync

0 commit comments

Comments
 (0)