Skip to content

Commit acdccd4

Browse files
tomvdw and The TensorFlow Datasets Authors
authored and committed
Add support for specifying how data should be deserialized in tfds.data_source
If you're interested in the raw bytes or the deserialized but not decoded examples, then you can set the deserialize_method parameter accordingly. PiperOrigin-RevId: 674218072
1 parent f62f596 commit acdccd4

File tree

8 files changed

+110
-17
lines changed

8 files changed

+110
-17
lines changed

tensorflow_datasets/core/data_sources/array_record.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class ArrayRecordDataSource(base.BaseDataSource):
4747
decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
4848
None
4949
)
50+
deserialize_method: decode.DeserializeMethod = (
51+
decode.DeserializeMethod.DESERIALIZE_AND_DECODE
52+
)
5053
# In order to lazy load array_record, we don't load
5154
# `array_record_data_source.ArrayRecordDataSource` here.
5255
data_source: Any = dataclasses.field(init=False)

tensorflow_datasets/core/data_sources/base.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,34 @@ class BaseDataSource(MappingView, Sequence):
6868
split: The split to load in the data source.
6969
decoders: Optional decoders for decoding.
7070
data_source: The underlying data source to initialize in the __post_init__.
71+
deserialize_method: How to deserialize the bytes that are read before
72+
returning.
7173
"""
7274

7375
dataset_info: dataset_info_lib.DatasetInfo
7476
split: splits_lib.Split | None = None
7577
decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
7678
data_source: DataSource[Any] = dataclasses.field(init=False)
79+
deserialize_method: decode.DeserializeMethod = (
80+
decode.DeserializeMethod.DESERIALIZE_AND_DECODE
81+
)
82+
83+
def _deserialize(self, record: Any) -> Any:
84+
match self.deserialize_method:
85+
case decode.DeserializeMethod.RAW_BYTES:
86+
return record
87+
case decode.DeserializeMethod.DESERIALIZE_NO_DECODE:
88+
if file_format := self.dataset_info.file_format:
89+
return file_format.deserialize(record)
90+
raise ValueError('No file format set, cannot deserialize bytes!')
91+
case decode.DeserializeMethod.DESERIALIZE_AND_DECODE:
92+
if features := self.dataset_info.features:
93+
return features.deserialize_example_np(record, decoders=self.decoders) # pylint: disable=attribute-error
94+
raise ValueError('No features set, cannot decode example!')
7795

7896
def __getitem__(self, key: SupportsIndex) -> Any:
7997
record = self.data_source[key.__index__()]
80-
return self.dataset_info.features.deserialize_example_np(
81-
record, decoders=self.decoders
82-
)
98+
return self._deserialize(record)
8399

84100
def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
85101
"""Retrieves items by batch.
@@ -98,17 +114,12 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
98114
if not keys:
99115
return []
100116
records = self.data_source.__getitems__(keys)
101-
features = self.dataset_info.features
102117
if len(keys) != len(records):
103118
raise IndexError(
104-
f'Requested {len(keys)} records but got'
105-
f' {len(records)} records.'
119+
f'Requested {len(keys)} records but got {len(records)} records.'
106120
f'{keys=}, {records=}'
107121
)
108-
return [
109-
features.deserialize_example_np(record, decoders=self.decoders)
110-
for record in records
111-
]
122+
return [self._deserialize(record) for record in records]
112123

113124
def __repr__(self) -> str:
114125
decoders_repr = (

tensorflow_datasets/core/dataset_builder.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ def as_data_source(
799799
split: Optional[Tree[splits_lib.SplitArg]] = None,
800800
*,
801801
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None,
802+
deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
802803
) -> ListOrTreeOrElem[Sequence[Any]]:
803804
"""Constructs an `ArrayRecordDataSource`.
804805
@@ -812,6 +813,11 @@ def as_data_source(
812813
customized feature keys need to be present. See [the
813814
guide](https://github.com/tensorflow/datasets/blob/master/docs/decode.md)
814815
for more info.
816+
deserialize_method: Whether the read examples should be deserialized
817+
and/or decoded. If not specified, it'll deserialize the data and decode
818+
the features. Decoding is only supported if the examples are tf
819+
examples. Note that if the deserialize_method is other than
 820+
DESERIALIZE_AND_DECODE, then the `decoders` argument is ignored.
815821
816822
Returns:
817823
`Sequence` if `split`,
@@ -866,13 +872,27 @@ def as_data_source(
866872

867873
# Create a dataset for each of the given splits
868874
def build_single_data_source(split: str) -> Sequence[Any]:
875+
if info.file_format is None:
876+
raise ValueError(
877+
"Dataset info file format is not set! For random access, one of the"
878+
f" following formats is required: {random_access_formats_msg}"
879+
)
880+
869881
match info.file_format:
870882
case file_adapters.FileFormat.ARRAY_RECORD:
871883
return array_record.ArrayRecordDataSource(
872-
info, split=split, decoders=decoders
884+
info,
885+
split=split,
886+
decoders=decoders,
887+
deserialize_method=deserialize_method,
873888
)
874889
case file_adapters.FileFormat.PARQUET:
875-
return parquet.ParquetDataSource(info, split=split, decoders=decoders)
890+
return parquet.ParquetDataSource(
891+
info,
892+
split=split,
893+
decoders=decoders,
894+
deserialize_method=deserialize_method,
895+
)
876896
case _:
877897
raise NotImplementedError(unsupported_format_msg)
878898

tensorflow_datasets/core/decode/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
"""Decoder public API."""
1717

1818
from tensorflow_datasets.core.decode.base import Decoder
19+
from tensorflow_datasets.core.decode.base import DeserializeMethod
1920
from tensorflow_datasets.core.decode.base import make_decoder
2021
from tensorflow_datasets.core.decode.base import SkipDecoding
2122
from tensorflow_datasets.core.decode.partial_decode import PartialDecoding
2223

2324
__all__ = [
2425
'Decoder',
2526
'make_decoder',
27+
'DeserializeMethod',
2628
'PartialDecoding',
2729
'SkipDecoding',
2830
]

tensorflow_datasets/core/decode/base.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Base decoders."""
1717

1818
import abc
19+
import enum
1920
import functools
2021

2122
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
@@ -210,3 +211,30 @@ def decorated(*args, **kwargs):
210211
return decorated
211212

212213
return decorator
214+
215+
216+
class DeserializeMethod(enum.Enum):
217+
"""How to deserialize the bytes that are read before returning.
218+
219+
When reading examples from a source (e.g., a file), we consider 2 phases in
220+
parsing the raw data:
221+
222+
1. Deserialize: deserializes raw bytes into an object. Typically it will be
223+
deserialized into a `tf.train.Example`.
224+
225+
2. Decode: A `tf.train.Example` might encode information (e.g., a bytes
226+
feature encodes an image or a int64 list encodes a tensor). The second
227+
phase decodes the encoded information.
228+
229+
DESERIALIZE_AND_DECODE: deserialize the raw bytes to tf example (if file
230+
format doesn't have a custom encoding) and then decode the features. Note
231+
that how and what is decoded can typically be overridden with `decoders`.
232+
DESERIALIZE_NO_DECODE: deserialize the raw bytes to tf example (if file format
 233+
doesn't have a custom encoding). If this deserialize method is used, then all
 234+
decoders are ignored.
 235+
RAW_BYTES: don't deserialize nor decode, but return the raw bytes.
236+
"""
237+
238+
DESERIALIZE_AND_DECODE = 'deserialize_and_decode'
239+
DESERIALIZE_NO_DECODE = 'deserialize_no_decode'
240+
RAW_BYTES = 'raw_bytes'

tensorflow_datasets/core/features/feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def decode_example(self, tfexample_data):
765765
def decode_example_np(
766766
self, example_data: type_utils.NpArrayOrScalar
767767
) -> type_utils.NpArrayOrScalar | None:
768-
"""Encode the feature dict into NumPy-compatible input.
768+
"""Decode the example data into NumPy-compatible input.
769769
770770
Args:
771771
example_data: Value to convert to NumPy.

tensorflow_datasets/core/file_adapters.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,20 @@
2525
import re
2626
from typing import Any, ClassVar, Type, TypeVar
2727

28-
from etils import epath
29-
from tensorflow_datasets.core.utils import file_utils
30-
from tensorflow_datasets.core.utils import type_utils
28+
from etils import epy
3129
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
3230
from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
3331
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
3432
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
3533

34+
with epy.lazy_imports():
35+
# pylint: disable=g-import-not-at-top
36+
from etils import epath
37+
from tensorflow_datasets.core.utils import file_utils
38+
from tensorflow_datasets.core.utils import type_utils
39+
40+
# pylint: enable=g-import-not-at-top
41+
3642
ExamplePositions = list[Any]
3743
T = TypeVar('T')
3844

@@ -52,6 +58,10 @@ class FileFormat(enum.Enum):
5258
def file_suffix(self) -> str:
5359
return ADAPTER_FOR_FORMAT[self].FILE_SUFFIX
5460

61+
def deserialize(self, raw_example: bytes) -> Any:
62+
"""Deserializes bytes into an object, but does not decode features."""
63+
return ADAPTER_FOR_FORMAT[self].deserialize(raw_example)
64+
5565
@classmethod
5666
def with_random_access(cls) -> set[FileFormat]:
5767
"""File formats with random access."""
@@ -146,6 +156,17 @@ def write_examples(
146156
"""
147157
raise NotImplementedError()
148158

159+
@classmethod
160+
def deserialize(cls, raw_example: bytes) -> Any:
161+
"""Returns the deserialized example, but does not decode features.
162+
163+
If custom serialization is used, override this method in the file adapter.
164+
165+
Args:
166+
raw_example: the bytes read from the source that should be deserialized.
167+
"""
168+
return tf.train.Example.FromString(raw_example)
169+
149170

150171
class TfRecordFileAdapter(FileAdapter):
151172
"""File adapter for TFRecord file format."""

tensorflow_datasets/core/load.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ def data_source(
705705
data_dir: Union[None, str, os.PathLike] = None, # pylint: disable=g-bare-generic
706706
download: bool = True,
707707
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None,
708+
deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
708709
builder_kwargs: Optional[Dict[str, Any]] = None,
709710
download_and_prepare_kwargs: Optional[Dict[str, Any]] = None,
710711
try_gcs: bool = False,
@@ -777,6 +778,11 @@ def data_source(
777778
customized feature keys need to be present. See [the
778779
guide](https://github.com/tensorflow/datasets/blob/master/docs/decode.md)
779780
for more info.
781+
deserialize_method: Whether the read examples should be deserialized and/or
782+
decoded. If not specified, it'll deserialize the data and decode the
783+
features. Decoding is only supported if the examples are tf examples.
784+
Note that if the deserialize_method is other than DESERIALIZE_AND_DECODE,
 785+
then the `decoders` argument is ignored.
780786
builder_kwargs: `dict` (optional), keyword arguments to be passed to the
781787
`tfds.core.DatasetBuilder` constructor. `data_dir` will be passed through
782788
by default.
@@ -807,7 +813,9 @@ def data_source(
807813
try_gcs,
808814
)
809815
_download_and_prepare_builder(dbuilder, download, download_and_prepare_kwargs)
810-
return dbuilder.as_data_source(split=split, decoders=decoders)
816+
return dbuilder.as_data_source(
817+
split=split, decoders=decoders, deserialize_method=deserialize_method
818+
)
811819

812820

813821
def _get_all_versions(

0 commit comments

Comments
 (0)