2 files changed: +15 -0 lines changed

from benchmark_utils import timed_section

from dataloading_utils import SyntheticDataset
+from dataloading_utils import ensure_dataset_on_gpu
from dataloading_utils import get_dequeue_batch_fn
from dataloading_utils import get_force_data_on_gpu_fn

@@ -352,6 +353,8 @@ def execute_benchmark(self):
                f"synthetic dataset. Performance numbers will be "
                f"impacted.\nError: {str(e)}."
            )
+        else:
+            dataset = ensure_dataset_on_gpu(dataset, device="GPU:0")

        @force_gpu_resync
        @tf.function(jit_compile=self._args.use_xla)
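For context, the added else branch runs only when the synthetic dataset was built without raising, so the GPU transfer is skipped whenever the benchmark falls back after a failure. A minimal sketch of that control flow, using a hypothetical build_synthetic_dataset stand-in (not a name from the benchmark script) together with the ensure_dataset_on_gpu helper added in the second file:

import tensorflow as tf
from dataloading_utils import ensure_dataset_on_gpu

def build_synthetic_dataset():
    # Hypothetical stand-in for the benchmark's SyntheticDataset setup.
    return tf.data.Dataset.from_tensor_slices(tf.zeros([8, 4])).batch(2)

try:
    dataset = build_synthetic_dataset()
except Exception as e:
    print(f"Could not build the synthetic dataset. Performance numbers will be "
          f"impacted.\nError: {str(e)}.")
else:
    # Success path: move the pipeline's output onto the GPU before benchmarking.
    dataset = ensure_dataset_on_gpu(dataset, device="GPU:0")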
@@ -36,6 +36,18 @@ def __iter__(self):
        yield data_batch


+def ensure_dataset_on_gpu(dataset, device):
+    if device.lower() not in dataset._variant_tensor_attr.device.lower():
+        return dataset.apply(
+            tf.data.experimental.prefetch_to_device(
+                device=device,
+                buffer_size=tf.data.experimental.AUTOTUNE
+            )
+        )
+    else:
+        return dataset
+
+
def get_dequeue_batch_fn(ds_iter):

    @force_gpu_resync
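The new helper only wraps the dataset with prefetch_to_device when its variant tensor is not already placed on the requested device; otherwise it returns the dataset unchanged. A short usage sketch, assuming a visible GPU and relying on the private _variant_tensor_attr attribute exactly as the helper above does:

import tensorflow as tf
from dataloading_utils import ensure_dataset_on_gpu

ds = tf.data.Dataset.range(1024).batch(32)        # host-side input pipeline
ds = ensure_dataset_on_gpu(ds, device="GPU:0")    # wraps it with prefetch_to_device

for batch in ds:
    # With a GPU present, the prefetched batch should report a GPU placement.
    print(batch.device)
    break

Note that tf.data.experimental.prefetch_to_device is documented as needing to be the final transformation in the pipeline, which is why the helper is applied after the dataset is fully built rather than inside the input-pipeline definition.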