Commit 6e26fad

The TensorFlow Datasets Authors committed
Finish getting TfdsDataSource mock to iterate correctly.
The numpy_function has to return 1 output, so stack serialized outputs. This allows the lookup inside TfDataLoader to batch normally.

PiperOrigin-RevId: 628080888
1 parent 465d3dd commit 6e26fad

1 file changed: +4 -1 lines changed


tensorflow_datasets/testing/mocking.py

Lines changed: 4 additions & 1 deletion
@@ -116,10 +116,13 @@ def _getitems(
     serialized: bool = False,
 ) -> Sequence[Any]:
   """Function to overwrite __getitems__ in data sources."""
-  return [
+  items = [
       _getitem(self, record_key, generator, serialized=serialized)
       for record_key in record_keys
   ]
+  if serialized:
+    return np.array(items)
+  return items
 
 
 def _deserialize_example_np(serialized_example, *, decoders=None):
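
The effect of the change can be sketched in isolation with NumPy alone. This is a minimal illustration, not the library's code: `fake_getitem` and `fake_getitems` are hypothetical stand-ins for the mocked `_getitem` and `_getitems`, and the byte strings are placeholders for serialized examples. It shows that stacking the serialized items turns a Python list of byte strings into a single 1-D array, which is the kind of single output the commit message says the numpy_function needs.

```python
import numpy as np


def fake_getitem(record_key, serialized=False):
    """Hypothetical stand-in for the mocked _getitem (not the real helper)."""
    if serialized:
        # Placeholder bytes standing in for a serialized example.
        return f"example-{record_key}".encode("utf-8")
    return {"id": record_key}


def fake_getitems(record_keys, serialized=False):
    """Mirrors the shape of the patched _getitems, under the assumptions above."""
    items = [fake_getitem(key, serialized=serialized) for key in record_keys]
    if serialized:
        # Stack the byte strings into one array so the caller gets one output.
        return np.array(items)
    return items


serialized_batch = fake_getitems([0, 1, 2], serialized=True)
plain_batch = fake_getitems([0, 1, 2])

print(type(serialized_batch).__name__, serialized_batch.shape)
print(type(plain_batch).__name__, len(plain_batch))
```

In the non-serialized path the return value is still a plain list, so callers that index into decoded examples are unaffected.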
