Skip to content

Commit 7051dea

Browse files
tomvdw and The TensorFlow Datasets Authors
authored and committed
Add support for loading data in a specific file format
PiperOrigin-RevId: 704282560
1 parent 5d2c3e5 commit 7051dea

File tree

3 files changed

+95
-24
lines changed

3 files changed

+95
-24
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 59 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -563,9 +563,9 @@ def get_reference(
563563
data_dir=self.data_dir_root,
564564
)
565565

566-
def get_file_spec(self, split: str) -> str:
566+
def get_file_spec(self, split: str) -> str | None:
567567
"""Returns the file spec of the split."""
568-
split_info: splits_lib.SplitInfo = self.info.splits[split]
568+
split_info = self.info.splits[split]
569569
return split_info.file_spec(self.info.file_format)
570570

571571
def is_prepared(self) -> bool:
@@ -815,6 +815,7 @@ def as_data_source(
815815
*,
816816
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
817817
deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
818+
file_format: str | file_adapters.FileFormat | None = None,
818819
) -> ListOrTreeOrElem[Sequence[Any]]:
819820
"""Constructs an `ArrayRecordDataSource`.
820821
@@ -833,6 +834,9 @@ def as_data_source(
833834
the features. Decoding is only supported if the examples are tf
834835
examples. Note that if the deserialize_method method is other than
835836
PARSE_AND_DECODE, then the `decoders` argument is ignored.
837+
file_format: if the dataset is stored in multiple file formats, then this
838+
can be used to specify which format to use. If not provided, we will
839+
default to the first available format.
836840
837841
Returns:
838842
`Sequence` if `split`,
@@ -868,22 +872,31 @@ def as_data_source(
868872
"Dataset info file format is not set! For random access, one of the"
869873
f" following formats is required: {random_access_formats_msg}"
870874
)
871-
872875
suitable_formats = available_formats.intersection(random_access_formats)
873-
if suitable_formats:
876+
if not suitable_formats:
877+
raise NotImplementedError(unsupported_format_msg)
878+
879+
if file_format is not None:
880+
file_format = file_adapters.FileFormat.from_value(file_format)
881+
if file_format not in suitable_formats:
882+
raise ValueError(
883+
f"Requested file format {file_format} is not available for this"
884+
f" dataset. Available formats: {available_formats}"
885+
)
886+
chosen_format = file_format
887+
else:
874888
chosen_format = suitable_formats.pop()
875889
logging.info(
876890
"Found random access formats: %s. Chose to use %s. Overriding file"
877891
" format in the dataset info.",
878892
", ".join([f.name for f in suitable_formats]),
879893
chosen_format,
880894
)
881-
# Change the dataset info to read from a random access format.
882-
info.set_file_format(
883-
chosen_format, override=True, override_if_initialized=True
884-
)
885-
else:
886-
raise NotImplementedError(unsupported_format_msg)
895+
896+
# Change the dataset info to read from a random access format.
897+
info.set_file_format(
898+
chosen_format, override=True, override_if_initialized=True
899+
)
887900

888901
# Create a dataset for each of the given splits
889902
def build_single_data_source(split: str) -> Sequence[Any]:
@@ -924,6 +937,7 @@ def as_dataset(
924937
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
925938
read_config: read_config_lib.ReadConfig | None = None,
926939
as_supervised: bool = False,
940+
file_format: str | file_adapters.FileFormat | None = None,
927941
):
928942
# pylint: disable=line-too-long
929943
"""Constructs a `tf.data.Dataset`.
@@ -993,6 +1007,9 @@ def as_dataset(
9931007
a 2-tuple structure `(input, label)` according to
9941008
`builder.info.supervised_keys`. If `False`, the default, the returned
9951009
`tf.data.Dataset` will have a dictionary with all the features.
1010+
file_format: if the dataset is stored in multiple file formats, then this
1011+
argument can be used to specify the file format to load. If not
1012+
specified, the default file format is used.
9961013
9971014
Returns:
9981015
`tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -1026,6 +1043,7 @@ def as_dataset(
10261043
decoders=decoders,
10271044
read_config=read_config,
10281045
as_supervised=as_supervised,
1046+
file_format=file_format,
10291047
)
10301048
all_ds = tree.map_structure(build_single_dataset, split)
10311049
return all_ds
@@ -1038,19 +1056,29 @@ def _build_single_dataset(
10381056
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
10391057
read_config: read_config_lib.ReadConfig,
10401058
as_supervised: bool,
1059+
file_format: str | file_adapters.FileFormat | None = None,
10411060
) -> tf.data.Dataset:
10421061
"""as_dataset for a single split."""
10431062
wants_full_dataset = batch_size == -1
10441063
if wants_full_dataset:
10451064
batch_size = self.info.splits.total_num_examples or sys.maxsize
10461065

1066+
if file_format is not None:
1067+
file_format = file_adapters.FileFormat.from_value(file_format)
1068+
10471069
# Build base dataset
1048-
ds = self._as_dataset(
1049-
split=split,
1050-
shuffle_files=shuffle_files,
1051-
decoders=decoders,
1052-
read_config=read_config,
1053-
)
1070+
as_dataset_kwargs = {
1071+
"split": split,
1072+
"shuffle_files": shuffle_files,
1073+
"decoders": decoders,
1074+
"read_config": read_config,
1075+
}
1076+
# Not all dataset builder classes support file_format, so only pass it if
1077+
# it's supported.
1078+
if "file_format" in inspect.signature(self._as_dataset).parameters:
1079+
as_dataset_kwargs["file_format"] = file_format
1080+
ds = self._as_dataset(**as_dataset_kwargs)
1081+
10541082
# Auto-cache small datasets which are small enough to fit in memory.
10551083
if self._should_cache_ds(
10561084
split=split, shuffle_files=shuffle_files, read_config=read_config
@@ -1235,6 +1263,7 @@ def _as_dataset(
12351263
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
12361264
read_config: read_config_lib.ReadConfig | None = None,
12371265
shuffle_files: bool = False,
1266+
file_format: str | file_adapters.FileFormat | None = None,
12381267
) -> tf.data.Dataset:
12391268
"""Constructs a `tf.data.Dataset`.
12401269
@@ -1250,6 +1279,9 @@ def _as_dataset(
12501279
read_config: `tfds.ReadConfig`
12511280
shuffle_files: `bool`, whether to shuffle the input files. Optional,
12521281
defaults to `False`.
1282+
file_format: if the dataset is stored in multiple file formats, then this
1283+
argument can be used to specify the file format to load. If not
1284+
specified, the default file format is used.
12531285
12541286
Returns:
12551287
`tf.data.Dataset`
@@ -1487,6 +1519,10 @@ def __init__(
14871519

14881520
@functools.cached_property
14891521
def _example_specs(self):
1522+
if self.info.features is None:
1523+
raise ValueError(
1524+
f"Features are not set for dataset {self.name} in {self.data_dir}!"
1525+
)
14901526
return self.info.features.get_serialized_info()
14911527

14921528
def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
@@ -1495,6 +1531,7 @@ def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-t
14951531
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
14961532
read_config: read_config_lib.ReadConfig,
14971533
shuffle_files: bool,
1534+
file_format: file_adapters.FileFormat | None = None,
14981535
) -> tf.data.Dataset:
14991536
# Partial decoding
15001537
# TODO(epot): Should be moved inside `features.decode_example`
@@ -1508,10 +1545,15 @@ def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-t
15081545
example_specs = self._example_specs
15091546
decoders = decoders # pylint: disable=self-assigning-variable
15101547

1548+
if features is None:
1549+
raise ValueError(
1550+
f"Features are not set for dataset {self.name} in {self.data_dir}!"
1551+
)
1552+
15111553
reader = reader_lib.Reader(
15121554
self.data_dir,
15131555
example_specs=example_specs,
1514-
file_format=self.info.file_format,
1556+
file_format=file_format or self.info.file_format,
15151557
)
15161558
decode_fn = functools.partial(features.decode_example, decoders=decoders)
15171559
return reader.read(

tensorflow_datasets/core/load.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from collections.abc import Iterable, Iterator, Mapping, Sequence
2121
import dataclasses
2222
import difflib
23+
import inspect
2324
import posixpath
2425
import re
2526
import textwrap
@@ -226,7 +227,7 @@ def _try_load_from_files_first(
226227
**builder_kwargs: Any,
227228
) -> bool:
228229
"""Returns True if files should be used rather than code."""
229-
if set(builder_kwargs) - {'version', 'config', 'data_dir'}:
230+
if set(builder_kwargs) - {'version', 'config', 'data_dir', 'file_format'}:
230231
return False # Has extra kwargs, requires original code.
231232
elif builder_kwargs.get('version') == 'experimental_latest':
232233
return False # Requested version requires original code
@@ -485,10 +486,13 @@ def _fetch_builder(
485486
data_dir: epath.PathLike | None,
486487
builder_kwargs: dict[str, Any] | None,
487488
try_gcs: bool,
489+
file_format: str | file_adapters.FileFormat | None = None,
488490
) -> dataset_builder.DatasetBuilder:
489491
"""Fetches the `tfds.core.DatasetBuilder` by name."""
490492
if builder_kwargs is None:
491493
builder_kwargs = {}
494+
if file_format is not None:
495+
builder_kwargs['file_format'] = file_format
492496
return builder(name, data_dir=data_dir, try_gcs=try_gcs, **builder_kwargs)
493497

494498

@@ -529,6 +533,7 @@ def load(
529533
download_and_prepare_kwargs: dict[str, Any] | None = None,
530534
as_dataset_kwargs: dict[str, Any] | None = None,
531535
try_gcs: bool = False,
536+
file_format: str | file_adapters.FileFormat | None = None,
532537
):
533538
# pylint: disable=line-too-long
534539
"""Loads the named dataset into a `tf.data.Dataset`.
@@ -636,6 +641,9 @@ def load(
636641
fully bypass GCS, please use `try_gcs=False` and
637642
`download_and_prepare_kwargs={'download_config':
638643
tfds.core.download.DownloadConfig(try_download_gcs=False)})`.
644+
file_format: if the dataset is stored in multiple file formats, then this
645+
argument can be used to specify the file format to load. If not specified,
646+
the default file format is used.
639647
640648
Returns:
641649
ds: `tf.data.Dataset`, the dataset requested, or if `split` is None, a
@@ -648,10 +656,10 @@ def load(
648656
Split-specific information is available in `ds_info.splits`.
649657
""" # fmt: skip
650658
dbuilder = _fetch_builder(
651-
name,
652-
data_dir,
653-
builder_kwargs,
654-
try_gcs,
659+
name=name,
660+
data_dir=data_dir,
661+
builder_kwargs=builder_kwargs,
662+
try_gcs=try_gcs,
655663
)
656664
_download_and_prepare_builder(dbuilder, download, download_and_prepare_kwargs)
657665

@@ -664,6 +672,8 @@ def load(
664672
as_dataset_kwargs.setdefault('decoders', decoders)
665673
as_dataset_kwargs.setdefault('shuffle_files', shuffle_files)
666674
as_dataset_kwargs.setdefault('read_config', read_config)
675+
if 'file_format' in inspect.signature(dbuilder.as_dataset).parameters:
676+
as_dataset_kwargs.setdefault('file_format', file_format)
667677

668678
ds = dbuilder.as_dataset(**as_dataset_kwargs)
669679
if with_info:

tensorflow_datasets/core/read_only_builder.py

Lines changed: 21 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from etils import etree
3434
from tensorflow_datasets.core import dataset_builder
3535
from tensorflow_datasets.core import dataset_info
36+
from tensorflow_datasets.core import file_adapters
3637
from tensorflow_datasets.core import logging as tfds_logging
3738
from tensorflow_datasets.core import naming
3839
from tensorflow_datasets.core import registered
@@ -57,6 +58,7 @@ def __init__(
5758
builder_dir: epath.PathLike,
5859
*,
5960
info_proto: dataset_info_pb2.DatasetInfo | None = None,
61+
file_format: str | file_adapters.FileFormat | None = None,
6062
):
6163
"""Constructor.
6264
@@ -66,6 +68,8 @@ def __init__(
6668
info_proto: DatasetInfo describing the name, config, etc of the requested
6769
dataset. Note that this overwrites dataset info that may be present in
6870
builder_dir.
71+
file_format: The desired file format to use for the dataset. If not
72+
specified, the file format in the DatasetInfo is used.
6973
7074
Raises:
7175
FileNotFoundError: If the builder_dir does not exist.
@@ -74,6 +78,15 @@ def __init__(
7478
if not info_proto:
7579
info_proto = dataset_info.read_proto_from_builder_dir(builder_dir)
7680
self._info_proto = info_proto
81+
if file_format is not None:
82+
file_format = file_adapters.FileFormat.from_value(file_format)
83+
available_formats = set([self._info_proto.file_format])
84+
available_formats.update(self._info_proto.alternative_file_formats)
85+
if file_format.file_suffix not in available_formats:
86+
raise ValueError(
87+
f'File format {file_format.file_suffix} does not match the file'
88+
f' formats in the DatasetInfo: {sorted(available_formats)}.'
89+
)
7790

7891
self.name = info_proto.name
7992
self.VERSION = version_lib.Version(info_proto.version) # pylint: disable=invalid-name
@@ -92,6 +105,7 @@ def __init__(
92105
data_dir=builder_dir,
93106
config=builder_config,
94107
version=info_proto.version,
108+
file_format=file_format,
95109
)
96110
self.assert_is_not_blocked()
97111

@@ -154,6 +168,7 @@ def _download_and_prepare(self, **kwargs): # pylint: disable=arguments-differ
154168

155169
def builder_from_directory(
156170
builder_dir: epath.PathLike,
171+
file_format: str | file_adapters.FileFormat | None = None,
157172
) -> dataset_builder.DatasetBuilder:
158173
"""Loads a `tfds.core.DatasetBuilder` from the given generated dataset path.
159174
@@ -171,11 +186,13 @@ def builder_from_directory(
171186
Args:
172187
builder_dir: Path of the directory containing the dataset to read ( e.g.
173188
`~/tensorflow_datasets/mnist/3.0.0/`).
189+
file_format: The desired file format to use for the dataset. If not
190+
specified, the default file format in the DatasetInfo is used.
174191
175192
Returns:
176193
builder: `tfds.core.DatasetBuilder`, builder for dataset at the given path.
177194
"""
178-
return ReadOnlyBuilder(builder_dir=builder_dir)
195+
return ReadOnlyBuilder(builder_dir=builder_dir, file_format=file_format)
179196

180197

181198
def builder_from_directories(
@@ -308,7 +325,8 @@ def builder_from_files(
308325
f'and that it has been generated in: {data_dirs}. If the dataset has'
309326
' configs, you might have to specify the config name.'
310327
)
311-
return builder_from_directory(builder_dir)
328+
file_format = builder_kwargs.pop('file_format', None)
329+
return builder_from_directory(builder_dir, file_format=file_format)
312330

313331

314332
def _find_builder_dir(name: str, **builder_kwargs: Any) -> epath.Path | None:
@@ -339,6 +357,7 @@ def _find_builder_dir(name: str, **builder_kwargs: Any) -> epath.Path | None:
339357
version = str(version) if version else None
340358
config = builder_kwargs.pop('config', None)
341359
data_dir = builder_kwargs.pop('data_dir', None)
360+
_ = builder_kwargs.pop('file_format', None)
342361

343362
# Builder cannot be found if it uses:
344363
# * namespace

0 commit comments

Comments (0)