Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 64ff830

Browse files
DEKHTIARJonathanDEKHTIARJonathan
authored andcommitted
Test Batching get_force_data_on_gpu_fn
1 parent 535c132 commit 64ff830

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

tftrt/examples/dataloading_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,10 @@ def get_force_data_on_gpu_fn(device="/gpu:0", use_xla=False):
7575
def force_data_on_gpu_fn(data):
7676
with tf.device(device):
7777
if isinstance(data, (list, tuple)):
78-
output_data = list()
79-
for t in data:
80-
output_data.append(tf.identity(t))
78+
return tf.identity_n(data)
8179
elif isinstance(data, dict):
82-
output_data = dict()
83-
for k, v in data.items():
84-
output_data[k] = tf.identity(v)
80+
return dict(zip(data.keys(), tf.identity_n(list(data.values()))))
8581
else:
86-
output_data = tf.identity(data)
87-
88-
return output_data
82+
return tf.identity(data)
8983

9084
return force_data_on_gpu_fn

0 commit comments

Comments
 (0)