Skip to content

Commit 4721f7d

Browse files
tomvdw (The TensorFlow Datasets Authors)
authored and committed
Move file_format parameter from _as_dataset to ReadConfig because this is breaking external code that overrides _as_dataset.
PiperOrigin-RevId: 704667761
1 parent 8807a05 commit 4721f7d

File tree

3 files changed

+47
-44
lines changed

3 files changed

+47
-44
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,6 @@ def as_dataset(
937937
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
938938
read_config: read_config_lib.ReadConfig | None = None,
939939
as_supervised: bool = False,
940-
file_format: str | file_adapters.FileFormat | None = None,
941940
):
942941
# pylint: disable=line-too-long
943942
"""Constructs a `tf.data.Dataset`.
@@ -1007,9 +1006,6 @@ def as_dataset(
10071006
a 2-tuple structure `(input, label)` according to
10081007
`builder.info.supervised_keys`. If `False`, the default, the returned
10091008
`tf.data.Dataset` will have a dictionary with all the features.
1010-
file_format: if the dataset is stored in multiple file formats, then this
1011-
argument can be used to specify the file format to load. If not
1012-
specified, the default file format is used.
10131009
10141010
Returns:
10151011
`tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
@@ -1043,7 +1039,6 @@ def as_dataset(
10431039
decoders=decoders,
10441040
read_config=read_config,
10451041
as_supervised=as_supervised,
1046-
file_format=file_format,
10471042
)
10481043
all_ds = tree.map_structure(build_single_dataset, split)
10491044
return all_ds
@@ -1056,28 +1051,19 @@ def _build_single_dataset(
10561051
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
10571052
read_config: read_config_lib.ReadConfig,
10581053
as_supervised: bool,
1059-
file_format: str | file_adapters.FileFormat | None = None,
10601054
) -> tf.data.Dataset:
10611055
"""as_dataset for a single split."""
10621056
wants_full_dataset = batch_size == -1
10631057
if wants_full_dataset:
10641058
batch_size = self.info.splits.total_num_examples or sys.maxsize
10651059

1066-
if file_format is not None:
1067-
file_format = file_adapters.FileFormat.from_value(file_format)
1068-
10691060
# Build base dataset
1070-
as_dataset_kwargs = {
1071-
"split": split,
1072-
"shuffle_files": shuffle_files,
1073-
"decoders": decoders,
1074-
"read_config": read_config,
1075-
}
1076-
# Not all dataset builder classes support file_format, so only pass it if
1077-
# it's supported.
1078-
if "file_format" in inspect.signature(self._as_dataset).parameters:
1079-
as_dataset_kwargs["file_format"] = file_format
1080-
ds = self._as_dataset(**as_dataset_kwargs)
1061+
ds = self._as_dataset(
1062+
split=split,
1063+
shuffle_files=shuffle_files,
1064+
decoders=decoders,
1065+
read_config=read_config,
1066+
)
10811067

10821068
# Auto-cache small datasets which are small enough to fit in memory.
10831069
if self._should_cache_ds(
@@ -1263,7 +1249,6 @@ def _as_dataset(
12631249
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
12641250
read_config: read_config_lib.ReadConfig | None = None,
12651251
shuffle_files: bool = False,
1266-
file_format: str | file_adapters.FileFormat | None = None,
12671252
) -> tf.data.Dataset:
12681253
"""Constructs a `tf.data.Dataset`.
12691254
@@ -1279,9 +1264,6 @@ def _as_dataset(
12791264
read_config: `tfds.ReadConfig`
12801265
shuffle_files: `bool`, whether to shuffle the input files. Optional,
12811266
defaults to `False`.
1282-
file_format: if the dataset is stored in multiple file formats, then this
1283-
argument can be used to specify the file format to load. If not
1284-
specified, the default file format is used.
12851267
12861268
Returns:
12871269
`tf.data.Dataset`
@@ -1525,14 +1507,16 @@ def _example_specs(self):
15251507
)
15261508
return self.info.features.get_serialized_info()
15271509

1528-
def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
1510+
def _as_dataset(
15291511
self,
15301512
split: splits_lib.Split,
1531-
decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
1532-
read_config: read_config_lib.ReadConfig,
1533-
shuffle_files: bool,
1534-
file_format: file_adapters.FileFormat | None = None,
1513+
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
1514+
read_config: read_config_lib.ReadConfig | None = None,
1515+
shuffle_files: bool = False,
15351516
) -> tf.data.Dataset:
1517+
if read_config is None:
1518+
read_config = read_config_lib.ReadConfig()
1519+
15361520
# Partial decoding
15371521
# TODO(epot): Should be moved inside `features.decode_example`
15381522
if isinstance(decoders, decode.PartialDecoding):
@@ -1550,10 +1534,18 @@ def _as_dataset( # pytype: disable=signature-mismatch # overriding-parameter-t
15501534
f"Features are not set for dataset {self.name} in {self.data_dir}!"
15511535
)
15521536

1537+
file_format = (
1538+
read_config.file_format
1539+
or self.info.file_format
1540+
or file_adapters.DEFAULT_FILE_FORMAT
1541+
)
1542+
if file_format is not None:
1543+
file_format = file_adapters.FileFormat.from_value(file_format)
1544+
15531545
reader = reader_lib.Reader(
15541546
self.data_dir,
15551547
example_specs=example_specs,
1556-
file_format=file_format or self.info.file_format,
1548+
file_format=file_format,
15571549
)
15581550
decode_fn = functools.partial(features.decode_example, decoders=decoders)
15591551
return reader.read(

tensorflow_datasets/core/load.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from collections.abc import Iterable, Iterator, Mapping, Sequence
2121
import dataclasses
2222
import difflib
23-
import inspect
2423
import posixpath
2524
import re
2625
import textwrap
@@ -671,9 +670,10 @@ def load(
671670
as_dataset_kwargs.setdefault('batch_size', batch_size)
672671
as_dataset_kwargs.setdefault('decoders', decoders)
673672
as_dataset_kwargs.setdefault('shuffle_files', shuffle_files)
673+
if file_format is not None:
674+
read_config = read_config or read_config_lib.ReadConfig()
675+
read_config = read_config.replace(file_format=file_format)
674676
as_dataset_kwargs.setdefault('read_config', read_config)
675-
if 'file_format' in inspect.signature(dbuilder.as_dataset).parameters:
676-
as_dataset_kwargs.setdefault('file_format', file_format)
677677

678678
ds = dbuilder.as_dataset(**as_dataset_kwargs)
679679
if with_info:

tensorflow_datasets/core/utils/read_config.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
from __future__ import annotations
1919

20+
from collections.abc import Sequence
2021
import dataclasses
21-
from typing import Callable, Optional, Sequence, Union, cast
22+
from typing import Callable, cast
2223

24+
from tensorflow_datasets.core import file_adapters
2325
from tensorflow_datasets.core.utils import shard_utils
2426
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
2527

@@ -91,36 +93,45 @@ class ReadConfig:
9193
False if input files have been tempered with and they don't mind missing
9294
records or have too many of them.
9395
override_buffer_size: number of bytes to pass to file readers for buffering.
96+
file_format: if the dataset is stored in multiple file formats, then this
97+
argument can be used to specify the file format to load. If not specified,
98+
the default file format is used.
9499
"""
95100
# pyformat: enable
96101

97102
# General tf.data.Dataset parametters
98-
options: Optional[tf.data.Options] = None
103+
options: tf.data.Options | None = None
99104
try_autocache: bool = True
100105
repeat_filenames: bool = False
101106
add_tfds_id: bool = False
102107
# tf.data.Dataset.shuffle parameters
103-
shuffle_seed: Optional[int] = None
104-
shuffle_reshuffle_each_iteration: Optional[bool] = None
108+
shuffle_seed: int | None = None
109+
shuffle_reshuffle_each_iteration: bool | None = None
105110
# Interleave parameters
106111
# Ideally, we should switch interleave values to None to dynamically set
107112
# those value depending on the user system. However, this would make the
108113
# generation order non-deterministic accross machines.
109-
interleave_cycle_length: Union[Optional[int], _MISSING] = MISSING
110-
interleave_block_length: Optional[int] = 16
111-
input_context: Optional[tf.distribute.InputContext] = None
112-
experimental_interleave_sort_fn: Optional[InterleaveSortFn] = None
114+
interleave_cycle_length: int | None | _MISSING = MISSING
115+
interleave_block_length: int | None = 16
116+
input_context: tf.distribute.InputContext | None = None
117+
experimental_interleave_sort_fn: InterleaveSortFn | None = None
113118
skip_prefetch: bool = False
114-
num_parallel_calls_for_decode: Optional[int] = None
119+
num_parallel_calls_for_decode: int | None = None
115120
# Cast to an `int`. `__post_init__` will ensure the type invariant.
116-
num_parallel_calls_for_interleave_files: Optional[int] = cast(int, MISSING)
121+
num_parallel_calls_for_interleave_files: int | None = cast(int, MISSING)
117122
enable_ordering_guard: bool = True
118123
assert_cardinality: bool = True
119-
override_buffer_size: Optional[int] = None
124+
override_buffer_size: int | None = None
125+
file_format: str | file_adapters.FileFormat | None = None
120126

121127
def __post_init__(self):
122128
self.options = self.options or tf.data.Options()
123129
if self.num_parallel_calls_for_decode is None:
124130
self.num_parallel_calls_for_decode = tf.data.AUTOTUNE
125131
if self.num_parallel_calls_for_interleave_files == MISSING:
126132
self.num_parallel_calls_for_interleave_files = tf.data.AUTOTUNE
133+
if isinstance(self.file_format, str):
134+
self.file_format = file_adapters.FileFormat.from_value(self.file_format)
135+
136+
def replace(self, **kwargs) -> ReadConfig:
137+
return dataclasses.replace(self, **kwargs)

0 commit comments

Comments (0)