Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 1c3d83f

Browse files
author
DEKHTIARJonathan
committed
[Benchmarking Py] Bug Fix at input prefetch on GPU Compatibility Check
1 parent 62c24e2 commit 1c3d83f

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

tftrt/benchmarking-python/dataloading_utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,23 @@ def copy_on_device(data):
3434
return itertools.repeat(data_batch)
3535

3636

37+
def _validate_data_gpu_compatible(data):
38+
if isinstance(data, dict):
39+
return all([_validate_data_gpu_compatible(x) for x in data.values()])
40+
41+
elif isinstance(data, (tuple, list)):
42+
return all([_validate_data_gpu_compatible(x) for x in data])
43+
44+
else:
45+
return data.dtype != tf.int32
46+
47+
3748
def ensure_dataset_on_gpu(dataset, device):
3849

3950
# ensuring no tensor dtype == int32
4051
input_batch = next(iter(dataset))
41-
if isinstance(input_batch, dict):
42-
input_batch = input_batch.values()
43-
elif not isinstance(input_batch, (tuple, list)):
44-
input_batch = [input_batch]
4552

46-
if any([t.dtype == tf.int32 for t in input_batch]):
53+
if not _validate_data_gpu_compatible(input_batch):
4754
logging.warning(
4855
"The dataloader generates INT32 tensors. Prefetch to "
4956
"GPU not supported"

0 commit comments

Comments
 (0)