Skip to content

Commit 9f7dbed

Browse files
pierrot0 and The TensorFlow Datasets Authors
authored and committed
Fix np decoding for Dataset feature.
PiperOrigin-RevId: 647270132
1 parent b7d054a commit 9f7dbed

File tree

2 files changed

+52
-37
lines changed

2 files changed

+52
-37
lines changed

tensorflow_datasets/core/features/dataset_feature.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from tensorflow_datasets.core.data_sources import python
2525
from tensorflow_datasets.core.features import feature as feature_lib
2626
from tensorflow_datasets.core.features import sequence_feature
27-
from tensorflow_datasets.core.features import tensor_feature
2827
from tensorflow_datasets.core.features import top_level_feature
2928
from tensorflow_datasets.core.utils import py_utils
3029
from tensorflow_datasets.core.utils import type_utils
@@ -66,7 +65,7 @@ class Dataset(sequence_feature.Sequence):
6665
6766
```python
6867
features=tfds.features.FeaturesDict({
69-
'agent_id': np.object_,
68+
'agent_id': np.object_,
7069
'episode': tfds.features.Dataset({
7170
'observation': tfds.features.Image(),
7271
'reward': tfds.features.Image(),
@@ -176,23 +175,13 @@ def decode_example_np(
176175
flatten = self.feature._flatten # pylint: disable=protected-access
177176
nest = self.feature._nest # pylint: disable=protected-access
178177
flat_example = flatten(serialized_example)
179-
flat_features = flatten(self.feature)
180178
num_slices: int | None = None
181179

182-
# First discover the number of slices in the Dataset. Notably, it's possible
183-
# that tensors have to be reshaped. We call slice a record in the Dataset.
180+
# Discover the number of slices in the Dataset (ie: the outter dimension).
181+
# We call slice a record in the Dataset.
184182
# We don't use `example` to avoid confusion with the `serialized_example`.
185-
for i, feature in enumerate(flat_features):
186-
if isinstance(feature, tensor_feature.Tensor) and feature.shape:
187-
try:
188-
flat_example[i] = flat_example[i].reshape((-1,) + feature.shape)
189-
except ValueError as e:
190-
raise ValueError(
191-
"The length of all elements of one slice should be the same."
192-
) from e
193-
feature_num_slices = flat_example[i].shape[0]
194-
else:
195-
feature_num_slices = len(flat_example[i])
183+
for example_feature in flat_example:
184+
feature_num_slices = len(example_feature)
196185
if num_slices is None:
197186
num_slices = feature_num_slices
198187
else:

tensorflow_datasets/core/features/dataset_feature_test.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,9 @@ def test_getattr(self):
492492
)
493493
self.assertEqual(feature.names, ['left', 'right'])
494494

495-
feature = feature_lib.Dataset(
496-
{
497-
'label': feature_lib.ClassLabel(names=['left', 'right']),
498-
}
499-
)
495+
feature = feature_lib.Dataset({
496+
'label': feature_lib.ClassLabel(names=['left', 'right']),
497+
})
500498
self.assertEqual(feature['label'].names, ['left', 'right'])
501499

502500
def test_metadata(self):
@@ -512,24 +510,52 @@ def test_metadata(self):
512510

513511
class DecodeExampleNpTest(testing.SubTestCase):
514512

515-
def test_top_level_feature(self):
516-
feature = feature_lib.Dataset(
517-
{'feature_name': feature_lib.Tensor(dtype=np.uint8, shape=(4, 2))}
518-
)
519-
example = {'feature_name': np.ones(shape=(24,), dtype=np.int32)}
520-
expected = [{'feature_name': np.ones(shape=(4, 2), dtype=np.int32)}] * 3
521-
self.assertAllEqualNested(feature.decode_example_np(example), expected)
522-
523-
def test_tensor_feature(self):
524-
feature = feature_lib.Dataset(
525-
feature_lib.Tensor(dtype=np.uint8, shape=(4, 2))
526-
)
527-
example = np.ones(shape=(24,), dtype=np.uint8)
528-
expected = [np.ones(shape=(4, 2), dtype=np.int32)] * 3
529-
self.assertAllEqualNested(feature.decode_example_np(example), expected)
513+
def test_representative_example(self):
514+
feature = feature_lib.FeaturesDict({
515+
'step_number': feature_lib.Tensor(dtype=np.int32, shape=()),
516+
'steps': feature_lib.Dataset({
517+
'tensor': feature_lib.Tensor(dtype=np.uint8, shape=(7, 8)),
518+
'strings': feature_lib.Tensor(dtype=np.str_, shape=(3,)),
519+
'bool': feature_lib.Tensor(dtype=np.bool_, shape=()),
520+
'obj': feature_lib.FeaturesDict({
521+
'a': feature_lib.Tensor(
522+
dtype=np.float32,
523+
shape=(5,),
524+
encoding=feature_lib.Encoding.ZLIB,
525+
),
526+
'b': feature_lib.Tensor(dtype=np.int32, shape=(6,)),
527+
}),
528+
'reward': feature_lib.Tensor(dtype=np.float32, shape=()),
529+
}),
530+
'timestamp': feature_lib.Tensor(dtype=np.int64, shape=()),
531+
})
532+
subdataset_size = 42
533+
example = {
534+
'step_number': 7,
535+
'steps': [
536+
{
537+
'tensor': np.ones(shape=(7, 8), dtype=np.uint8),
538+
'strings': ['foo', 'bar', 'baz'],
539+
'bool': True,
540+
'obj': {
541+
'a': np.zeros(shape=(5,), dtype=np.float32),
542+
'b': np.zeros(shape=(6,), dtype=np.int32),
543+
},
544+
'reward': np.float32(42.42),
545+
}
546+
for _ in range(subdataset_size)
547+
],
548+
'timestamp': 1234567890,
549+
}
550+
encoded_example = feature.encode_example(example)
551+
decoded_encoded_example = feature.decode_example_np(encoded_example)
552+
self.assertAllEqualNested(decoded_encoded_example, example)
530553

531554
def test_nested_dict(self):
532-
feature = feature_lib.Dataset({'a': {'b': np.int32}, 'b': np.str_})
555+
feature = feature_lib.Dataset({
556+
'a': {'b': np.int32},
557+
'b': np.str_,
558+
})
533559
example = {'a': {'b': [1, 2, 3]}, 'b': ['a', 'b', 'c']}
534560
expected = [
535561
{'a': {'b': 1}, 'b': 'a'},

0 commit comments

Comments
 (0)