Skip to content

Commit a15ea31

Browse files
marcenacp and The TensorFlow Datasets Authors
authored and committed
Do not overload the builder_kwargs with the file_format if we intend to read from files.
In OSS, we still want to overwrite the file_format when the user just writes `tfds.data_source('some_dataset')` without any builder_kwargs, because the default file format for OSS (tfrecord) doesn't support random access.

Fixes: #5665
PiperOrigin-RevId: 686834897
1 parent 7e52fe4 commit a15ea31

File tree

2 files changed

+50
-41
lines changed

2 files changed

+50
-41
lines changed

tensorflow_datasets/core/load.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import dataclasses
2222
import difflib
2323
import json
24-
import os
2524
import posixpath
2625
import re
2726
import textwrap
2827
import typing
29-
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Type, Union
28+
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Type
3029

3130
from absl import logging
31+
from etils import epath
3232
from tensorflow_datasets.core import community
3333
from tensorflow_datasets.core import constants
3434
from tensorflow_datasets.core import dataset_builder
@@ -225,7 +225,7 @@ def get_dataset_repr() -> str:
225225

226226

227227
def _try_load_from_files_first(
228-
cls: Optional[Type[dataset_builder.DatasetBuilder]],
228+
cls: Type[dataset_builder.DatasetBuilder] | None,
229229
**builder_kwargs: Any,
230230
) -> bool:
231231
"""Returns True if files should be used rather than code."""
@@ -487,8 +487,8 @@ def dataset_collection(
487487

488488
def _fetch_builder(
489489
name: str,
490-
data_dir: Union[None, str, os.PathLike], # pylint: disable=g-bare-generic
491-
builder_kwargs: Optional[Dict[str, Any]],
490+
data_dir: epath.PathLike | None,
491+
builder_kwargs: dict[str, Any] | None,
492492
try_gcs: bool,
493493
) -> dataset_builder.DatasetBuilder:
494494
"""Fetches the `tfds.core.DatasetBuilder` by name."""
@@ -521,18 +521,18 @@ def _download_and_prepare_builder(
521521
def load(
522522
name: str,
523523
*,
524-
split: Optional[Tree[splits_lib.SplitArg]] = None,
525-
data_dir: Union[None, str, os.PathLike] = None, # pylint: disable=g-bare-generic
526-
batch_size: Optional[int] = None,
524+
split: Tree[splits_lib.SplitArg] | None = None,
525+
data_dir: epath.PathLike | None = None,
526+
batch_size: int | None = None,
527527
shuffle_files: bool = False,
528528
download: bool = True,
529529
as_supervised: bool = False,
530-
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None,
531-
read_config: Optional[read_config_lib.ReadConfig] = None,
530+
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
531+
read_config: read_config_lib.ReadConfig | None = None,
532532
with_info: bool = False,
533-
builder_kwargs: Optional[Dict[str, Any]] = None,
534-
download_and_prepare_kwargs: Optional[Dict[str, Any]] = None,
535-
as_dataset_kwargs: Optional[Dict[str, Any]] = None,
533+
builder_kwargs: dict[str, Any] | None = None,
534+
download_and_prepare_kwargs: dict[str, Any] | None = None,
535+
as_dataset_kwargs: dict[str, Any] | None = None,
536536
try_gcs: bool = False,
537537
):
538538
# pylint: disable=line-too-long
@@ -677,37 +677,49 @@ def load(
677677

678678

679679
def _set_file_format_for_data_source(
680-
builder_kwargs: Optional[Dict[str, Any]],
681-
) -> Dict[str, Any]:
680+
data_dir: epath.PathLike | None,
681+
builder_kwargs: dict[str, Any] | None,
682+
) -> dict[str, Any]:
682683
"""Normalizes file format in builder_kwargs for `tfds.data_source`."""
683684
if builder_kwargs is None:
684685
builder_kwargs = {}
685-
file_format = builder_kwargs.get(
686-
'file_format', file_adapters.FileFormat.ARRAY_RECORD
687-
)
686+
# If the user specified a builder_kwargs or a data_dir, we don't want to
687+
# overwrite it.
688+
if builder_kwargs or data_dir:
689+
return builder_kwargs
690+
return {'file_format': file_adapters.FileFormat.ARRAY_RECORD}
691+
692+
693+
def _validate_file_format_for_data_source(
694+
builder_kwargs: dict[str, Any],
695+
) -> None:
696+
"""Validates whether the file format supports random access."""
697+
file_format = builder_kwargs.get('file_format')
698+
if not file_format:
699+
# We don't raise an error because we let TFDS handle the default (e.g.,
700+
# when loading a dataset from files that support random access).
701+
return
688702
file_format = file_adapters.FileFormat.from_value(file_format)
689-
if file_format != file_adapters.FileFormat.ARRAY_RECORD:
703+
if file_format not in file_adapters.FileFormat.with_random_access():
690704
raise NotImplementedError(
691705
f'No random access data source for file format {file_format}. Please,'
692706
' use `tfds.data_source(...,'
693707
' builder_kwargs={"file_format":'
694708
f' {file_adapters.FileFormat.ARRAY_RECORD}}})` instead.'
695709
)
696-
builder_kwargs['file_format'] = file_format
697-
return builder_kwargs
698710

699711

700712
@tfds_logging.data_source()
701713
def data_source(
702714
name: str,
703715
*,
704-
split: Optional[Tree[splits_lib.SplitArg]] = None,
705-
data_dir: Union[None, str, os.PathLike] = None, # pylint: disable=g-bare-generic
716+
split: Tree[splits_lib.SplitArg] | None = None,
717+
data_dir: epath.PathLike | None = None,
706718
download: bool = True,
707-
decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None,
719+
decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
708720
deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
709-
builder_kwargs: Optional[Dict[str, Any]] = None,
710-
download_and_prepare_kwargs: Optional[Dict[str, Any]] = None,
721+
builder_kwargs: dict[str, Any] | None = None,
722+
download_and_prepare_kwargs: dict[str, Any] | None = None,
711723
try_gcs: bool = False,
712724
) -> type_utils.ListOrTreeOrElem[Sequence[Any]]:
713725
"""Gets a data source from the named dataset.
@@ -805,7 +817,8 @@ def data_source(
805817
`Sequence` if `split`,
806818
`dict<key: tfds.Split, value: Sequence>` otherwise.
807819
""" # fmt:skip
808-
builder_kwargs = _set_file_format_for_data_source(builder_kwargs)
820+
builder_kwargs = _set_file_format_for_data_source(data_dir, builder_kwargs)
821+
_validate_file_format_for_data_source(builder_kwargs)
809822
dbuilder = _fetch_builder(
810823
name,
811824
data_dir,

tensorflow_datasets/core/load_test.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,27 +140,23 @@ def test_load_dataset_with_kwargs(
140140
assert loaded_dataset == expected
141141

142142

143-
@pytest.mark.parametrize(
144-
'builder_kwargs',
145-
[
146-
None,
147-
{'file_format': 'array_record'},
148-
{'file_format': file_adapters.FileFormat.ARRAY_RECORD},
149-
],
150-
)
151-
def test_data_source_defaults_to_array_record_format(
152-
builder_kwargs,
153-
):
143+
def test_data_source_defaults_to_array_record_format():
154144
with mock.patch.object(load, 'builder', autospec=True) as mock_builder:
155-
load.data_source(
156-
'mydataset', builder_kwargs=builder_kwargs,
157-
)
145+
load.data_source('mydataset', builder_kwargs=None)
158146
mock_builder.assert_called_with(
159147
'mydataset',
160148
data_dir=None,
161149
try_gcs=False,
162150
file_format=file_adapters.FileFormat.ARRAY_RECORD,
163151
)
152+
def test_data_source_keeps_format_if_builder_kwargs():
153+
with mock.patch.object(load, 'builder', autospec=True) as mock_builder:
154+
load.data_source('mydataset', data_dir='/foo/bar')
155+
mock_builder.assert_called_with(
156+
'mydataset',
157+
data_dir='/foo/bar',
158+
try_gcs=False,
159+
)
164160
@pytest.mark.parametrize(
165161
'file_format',
166162
['tfrecord', file_adapters.FileFormat.TFRECORD],

0 commit comments

Comments (0)