Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 169e7cd

Browse files
author
DEKHTIARJonathan
committed
Error fix for DALIsets
1 parent b3b02f5 commit 169e7cd

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tftrt/examples/dataloading_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,19 @@ def __iter__(self):
3737

3838

3939
def ensure_dataset_on_gpu(dataset, device):
40-
if device.lower() not in dataset._variant_tensor_attr.device.lower():
40+
try:
41+
ds_device = dataset._variant_tensor_attr.device.lower()
42+
except AttributeError:
43+
return dataset
44+
45+
if device.lower() not in ds_device:
4146
return dataset.apply(
4247
tf.data.experimental.prefetch_to_device(
4348
device=device,
4449
buffer_size=tf.data.experimental.AUTOTUNE
4550
)
4651
)
52+
4753
else:
4854
return dataset
4955

0 commit comments

Comments
 (0)