
Commit 22f60a5

tomvdw authored and The TensorFlow Datasets Authors committed
Internal change
PiperOrigin-RevId: 691774442
1 parent a38292f commit 22f60a5

File tree: 1 file changed (+24, -24 lines)

tensorflow_datasets/core/load.py

Lines changed: 24 additions & 24 deletions
@@ -17,20 +17,18 @@
 
 from __future__ import annotations
 
-from collections.abc import Sequence
+from collections.abc import Iterable, Iterator, Mapping, Sequence
 import dataclasses
 import difflib
-import json
 import posixpath
 import re
 import textwrap
 import typing
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Type
+from typing import Any, Callable, Optional, Type
 
 from absl import logging
 from etils import epath
 from tensorflow_datasets.core import community
-from tensorflow_datasets.core import constants
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import dataset_collection_builder
 from tensorflow_datasets.core import decode
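
Note: this hunk replaces the deprecated `typing` aliases (`Dict`, `List`, `Iterable`, ...) with the built-in generics and their `collections.abc` counterparts, and drops the `json` and `constants` imports, which appear to be unused. A minimal sketch of the equivalence, for illustration only (not part of the diff):

    from collections.abc import Iterable

    # Old style (deprecated since Python 3.9):
    #   from typing import Dict, List
    #   def count_lengths(xs: List[str]) -> Dict[str, int]: ...

    # New style used by this change: built-in generics plus collections.abc.
    def count_lengths(xs: Iterable[str]) -> dict[str, int]:
      return {x: len(x) for x in xs}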
@@ -40,7 +38,6 @@
 from tensorflow_datasets.core import read_only_builder
 from tensorflow_datasets.core import registered
 from tensorflow_datasets.core import splits as splits_lib
-from tensorflow_datasets.core import utils
 from tensorflow_datasets.core import visibility
 from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder  # pylint:disable=unused-import
 from tensorflow_datasets.core.download import util
@@ -49,7 +46,7 @@
 from tensorflow_datasets.core.utils import py_utils
 from tensorflow_datasets.core.utils import read_config as read_config_lib
 from tensorflow_datasets.core.utils import type_utils
-from tensorflow_datasets.core.utils import version
+from tensorflow_datasets.core.utils import version as version_lib
 from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
 
 # pylint: disable=logging-format-interpolation
@@ -74,7 +71,7 @@
 def list_builders(
     *,
     with_community_datasets: bool = True,
-) -> List[str]:
+) -> list[str]:
   """Returns the string names of all `tfds.core.DatasetBuilder`s."""
   datasets = registered.list_imported_builders()
   if with_community_datasets:
@@ -83,7 +80,7 @@ def list_builders(
   return datasets
 
 
-def list_dataset_collections() -> List[str]:
+def list_dataset_collections() -> list[str]:
   """Returns the string names of all `tfds.core.DatasetCollectionBuilder`s."""
   collections = registered.list_imported_dataset_collections()
   return collections
@@ -124,7 +121,7 @@ def builder_cls(name: str) -> Type[dataset_builder.DatasetBuilder]:
     cls = typing.cast(Type[dataset_builder.DatasetBuilder], cls)
     return cls
   except registered.DatasetNotFoundError:
-    _add_list_builders_context(name=ds_name)  # pytype: disable=bad-return-type
+    _add_list_builders_context(name=ds_name)
     raise
 
 
@@ -173,6 +170,9 @@ def builder(
   name, builder_kwargs = naming.parse_builder_name_kwargs(
       name, **builder_kwargs
   )
+  # Make sure that `data_dir` is not set to an empty string or None.
+  if 'data_dir' in builder_kwargs and not builder_kwargs['data_dir']:
+    builder_kwargs.pop('data_dir')
 
   def get_dataset_repr() -> str:
     return f'dataset "{name}", builder_kwargs "{builder_kwargs}"'
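
Note: the added guard removes a falsy `data_dir` from `builder_kwargs`, so passing an empty string or `None` falls back to the default data directory instead of overriding it. An illustrative sketch of the effect (the kwargs values here are made up):

    builder_kwargs = {'data_dir': '', 'config': 'plain_text'}  # hypothetical kwargs
    if 'data_dir' in builder_kwargs and not builder_kwargs['data_dir']:
      builder_kwargs.pop('data_dir')
    assert builder_kwargs == {'config': 'plain_text'}  # falsy data_dir was dropped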
@@ -263,7 +263,7 @@ class DatasetCollectionLoader:
 
   collection: dataset_collection_builder.DatasetCollection
   requested_version: Optional[str] = None
-  loader_kwargs: Optional[Dict[str, Any]] = None
+  loader_kwargs: dict[str, Any] | None = None
 
   def __post_init__(self):
     self.datasets = self.collection.get_collection(self.requested_version)
@@ -298,14 +298,14 @@ def get_dataset_info(self, dataset_name: str):
       )
     return info
 
-  def set_loader_kwargs(self, loader_kwargs: Dict[str, Any]):
+  def set_loader_kwargs(self, loader_kwargs: dict[str, Any]):
     self.loader_kwargs = loader_kwargs
 
   def load_dataset(
       self,
       dataset: str,
       split: Optional[Tree[splits_lib.SplitArg]] = None,
-      loader_kwargs: Optional[Dict[str, Any]] = None,
+      loader_kwargs: dict[str, Any] | None = None,
   ) -> Mapping[str, tf.data.Dataset]:
     """Loads the named dataset from a dataset collection by calling `tfds.load`.
 
@@ -388,7 +388,7 @@ def load_datasets(
       self,
       datasets: Iterable[str],
       split: Optional[Tree[splits_lib.SplitArg]] = None,
-      loader_kwargs: Optional[Dict[str, Any]] = None,
+      loader_kwargs: dict[str, Any] | None = None,
   ) -> Mapping[str, Mapping[str, tf.data.Dataset]]:
     """Loads a number of datasets from the dataset collection.
 
@@ -418,7 +418,7 @@ def load_datasets(
   def load_all_datasets(
       self,
       split: Optional[Tree[splits_lib.SplitArg]] = None,
-      loader_kwargs: Optional[Dict[str, Any]] = None,
+      loader_kwargs: dict[str, Any] | None = None,
   ) -> Mapping[str, Mapping[str, tf.data.Dataset]]:
     """Loads all datasets of a collection.
 
@@ -440,7 +440,7 @@ def load_all_datasets(
 @tfds_logging.dataset_collection()
 def dataset_collection(
     name: str,
-    loader_kwargs: Optional[Dict[str, Any]] = None,
+    loader_kwargs: Optional[dict[str, Any]] = None,
 ) -> DatasetCollectionLoader:
   """Instantiates a DatasetCollectionLoader.
 
@@ -500,7 +500,7 @@ def _fetch_builder(
 def _download_and_prepare_builder(
     dbuilder: dataset_builder.DatasetBuilder,
     download: bool,
-    download_and_prepare_kwargs: Optional[Dict[str, Any]],
+    download_and_prepare_kwargs: Optional[dict[str, Any]],
 ) -> None:
   """Downloads and prepares the dataset builder if necessary."""
   if dbuilder.is_prepared():
@@ -594,7 +594,7 @@ def load(
     split: Which split of the data to load (e.g. `'train'`, `'test'`, `['train',
       'test']`, `'train[80%:]'`,...). See our [split API
       guide](https://www.tensorflow.org/datasets/splits). If `None`, will return
-      all splits in a `Dict[Split, tf.data.Dataset]`
+      all splits in a `dict[Split, tf.data.Dataset]`
     data_dir: directory to read/write data. Defaults to the value of the
       environment variable TFDS_DATA_DIR, if set, otherwise falls back to
       '~/tensorflow_datasets'.
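
Note: the docstring edit only modernizes the type spelling (`Dict` -> `dict`); the documented behavior is unchanged. A short usage sketch of that behavior (the dataset name is arbitrary):

    import tensorflow_datasets as tfds

    # With split=None, tfds.load returns a dict mapping split names to datasets.
    splits = tfds.load('mnist', split=None)
    train_ds = splits['train']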
@@ -776,7 +776,7 @@ def data_source(
     split: Which split of the data to load (e.g. `'train'`, `'test'`, `['train',
       'test']`, `'train[80%:]'`,...). See our [split API
       guide](https://www.tensorflow.org/datasets/splits). If `None`, will return
-      all splits in a `Dict[Split, Sequence]`
+      all splits in a `dict[Split, Sequence]`
     data_dir: directory to read/write data. Defaults to the value of the
       environment variable TFDS_DATA_DIR, if set, otherwise falls back to
       '~/tensorflow_datasets'.
@@ -832,11 +832,11 @@ def data_source(
 
 
 def _get_all_versions(
-    current_version: version.Version | None,
-    extra_versions: Iterable[version.Version],
+    current_version: version_lib.Version | None,
+    extra_versions: Iterable[version_lib.Version],
     current_version_only: bool,
-) -> Iterable[str]:
-  """Returns the list of all current versions."""
+) -> set[str]:
+  """Returns the set of all current versions."""
   # Merge current version with all extra versions
   version_list = [current_version] if current_version else []
   if not current_version_only:
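
Note: aliasing the import as `version_lib` presumably avoids clashes with local `version` names, and the `set[str]` return annotation documents that the helper yields a deduplicated set of version strings. A hedged sketch of such a merge, assuming a signature like the one above (this is not the function's actual body, which the diff does not show in full):

    from tensorflow_datasets.core.utils import version as version_lib

    def merge_versions(
        current_version: version_lib.Version | None,
        extra_versions: list[version_lib.Version],
    ) -> set[str]:
      # Hypothetical helper mirroring the annotated signature above.
      versions = [current_version] if current_version else []
      versions.extend(extra_versions)
      return {str(v) for v in versions}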
@@ -881,7 +881,7 @@ def _iter_full_names(current_version_only: bool) -> Iterator[str]:
       yield full_name
 
 
-def list_full_names(current_version_only: bool = False) -> List[str]:
+def list_full_names(current_version_only: bool = False) -> list[str]:
   """Lists all registered datasets full_names.
 
   Args:
@@ -896,7 +896,7 @@ def list_full_names(current_version_only: bool = False) -> List[str]:
 def single_full_names(
     builder_name: str,
     current_version_only: bool = True,
-) -> List[str]:
+) -> list[str]:
   """Returns the list `['ds/c0/v0',...]` or `['ds/v']` for a single builder."""
   return sorted(
       _iter_single_full_names(
