17
17
18
18
from __future__ import annotations
19
19
20
- from collections .abc import Sequence
20
+ from collections .abc import Iterable , Iterator , Mapping , Sequence
21
21
import dataclasses
22
22
import difflib
23
- import json
24
23
import posixpath
25
24
import re
26
25
import textwrap
27
26
import typing
28
- from typing import Any , Callable , Dict , Iterable , Iterator , List , Mapping , Optional , Type
27
+ from typing import Any , Callable , Optional , Type
29
28
30
29
from absl import logging
31
30
from etils import epath
32
31
from tensorflow_datasets .core import community
33
- from tensorflow_datasets .core import constants
34
32
from tensorflow_datasets .core import dataset_builder
35
33
from tensorflow_datasets .core import dataset_collection_builder
36
34
from tensorflow_datasets .core import decode
40
38
from tensorflow_datasets .core import read_only_builder
41
39
from tensorflow_datasets .core import registered
42
40
from tensorflow_datasets .core import splits as splits_lib
43
- from tensorflow_datasets .core import utils
44
41
from tensorflow_datasets .core import visibility
45
42
from tensorflow_datasets .core .dataset_builders import huggingface_dataset_builder # pylint:disable=unused-import
46
43
from tensorflow_datasets .core .download import util
49
46
from tensorflow_datasets .core .utils import py_utils
50
47
from tensorflow_datasets .core .utils import read_config as read_config_lib
51
48
from tensorflow_datasets .core .utils import type_utils
52
- from tensorflow_datasets .core .utils import version
49
+ from tensorflow_datasets .core .utils import version as version_lib
53
50
from tensorflow_datasets .core .utils .lazy_imports_utils import tensorflow as tf
54
51
55
52
# pylint: disable=logging-format-interpolation
74
71
def list_builders (
75
72
* ,
76
73
with_community_datasets : bool = True ,
77
- ) -> List [str ]:
74
+ ) -> list [str ]:
78
75
"""Returns the string names of all `tfds.core.DatasetBuilder`s."""
79
76
datasets = registered .list_imported_builders ()
80
77
if with_community_datasets :
@@ -83,7 +80,7 @@ def list_builders(
83
80
return datasets
84
81
85
82
86
- def list_dataset_collections () -> List [str ]:
83
+ def list_dataset_collections () -> list [str ]:
87
84
"""Returns the string names of all `tfds.core.DatasetCollectionBuilder`s."""
88
85
collections = registered .list_imported_dataset_collections ()
89
86
return collections
@@ -124,7 +121,7 @@ def builder_cls(name: str) -> Type[dataset_builder.DatasetBuilder]:
124
121
cls = typing .cast (Type [dataset_builder .DatasetBuilder ], cls )
125
122
return cls
126
123
except registered .DatasetNotFoundError :
127
- _add_list_builders_context (name = ds_name ) # pytype: disable=bad-return-type
124
+ _add_list_builders_context (name = ds_name )
128
125
raise
129
126
130
127
@@ -173,6 +170,9 @@ def builder(
173
170
name , builder_kwargs = naming .parse_builder_name_kwargs (
174
171
name , ** builder_kwargs
175
172
)
173
+ # Make sure that `data_dir` is not set to an empty string or None.
174
+ if 'data_dir' in builder_kwargs and not builder_kwargs ['data_dir' ]:
175
+ builder_kwargs .pop ('data_dir' )
176
176
177
177
def get_dataset_repr () -> str :
178
178
return f'dataset "{ name } ", builder_kwargs "{ builder_kwargs } "'
@@ -263,7 +263,7 @@ class DatasetCollectionLoader:
263
263
264
264
collection : dataset_collection_builder .DatasetCollection
265
265
requested_version : Optional [str ] = None
266
- loader_kwargs : Optional [ Dict [ str , Any ]] = None
266
+ loader_kwargs : dict [ str , Any ] | None = None
267
267
268
268
def __post_init__ (self ):
269
269
self .datasets = self .collection .get_collection (self .requested_version )
@@ -298,14 +298,14 @@ def get_dataset_info(self, dataset_name: str):
298
298
)
299
299
return info
300
300
301
- def set_loader_kwargs (self , loader_kwargs : Dict [str , Any ]):
301
+ def set_loader_kwargs (self , loader_kwargs : dict [str , Any ]):
302
302
self .loader_kwargs = loader_kwargs
303
303
304
304
def load_dataset (
305
305
self ,
306
306
dataset : str ,
307
307
split : Optional [Tree [splits_lib .SplitArg ]] = None ,
308
- loader_kwargs : Optional [ Dict [ str , Any ]] = None ,
308
+ loader_kwargs : dict [ str , Any ] | None = None ,
309
309
) -> Mapping [str , tf .data .Dataset ]:
310
310
"""Loads the named dataset from a dataset collection by calling `tfds.load`.
311
311
@@ -388,7 +388,7 @@ def load_datasets(
388
388
self ,
389
389
datasets : Iterable [str ],
390
390
split : Optional [Tree [splits_lib .SplitArg ]] = None ,
391
- loader_kwargs : Optional [ Dict [ str , Any ]] = None ,
391
+ loader_kwargs : dict [ str , Any ] | None = None ,
392
392
) -> Mapping [str , Mapping [str , tf .data .Dataset ]]:
393
393
"""Loads a number of datasets from the dataset collection.
394
394
@@ -418,7 +418,7 @@ def load_datasets(
418
418
def load_all_datasets (
419
419
self ,
420
420
split : Optional [Tree [splits_lib .SplitArg ]] = None ,
421
- loader_kwargs : Optional [ Dict [ str , Any ]] = None ,
421
+ loader_kwargs : dict [ str , Any ] | None = None ,
422
422
) -> Mapping [str , Mapping [str , tf .data .Dataset ]]:
423
423
"""Loads all datasets of a collection.
424
424
@@ -440,7 +440,7 @@ def load_all_datasets(
440
440
@tfds_logging .dataset_collection ()
441
441
def dataset_collection (
442
442
name : str ,
443
- loader_kwargs : Optional [Dict [str , Any ]] = None ,
443
+ loader_kwargs : Optional [dict [str , Any ]] = None ,
444
444
) -> DatasetCollectionLoader :
445
445
"""Instantiates a DatasetCollectionLoader.
446
446
@@ -500,7 +500,7 @@ def _fetch_builder(
500
500
def _download_and_prepare_builder (
501
501
dbuilder : dataset_builder .DatasetBuilder ,
502
502
download : bool ,
503
- download_and_prepare_kwargs : Optional [Dict [str , Any ]],
503
+ download_and_prepare_kwargs : Optional [dict [str , Any ]],
504
504
) -> None :
505
505
"""Downloads and prepares the dataset builder if necessary."""
506
506
if dbuilder .is_prepared ():
@@ -594,7 +594,7 @@ def load(
594
594
split: Which split of the data to load (e.g. `'train'`, `'test'`, `['train',
595
595
'test']`, `'train[80%:]'`,...). See our [split API
596
596
guide](https://www.tensorflow.org/datasets/splits). If `None`, will return
597
- all splits in a `Dict [Split, tf.data.Dataset]`
597
+ all splits in a `dict [Split, tf.data.Dataset]`
598
598
data_dir: directory to read/write data. Defaults to the value of the
599
599
environment variable TFDS_DATA_DIR, if set, otherwise falls back to
600
600
'~/tensorflow_datasets'.
@@ -776,7 +776,7 @@ def data_source(
776
776
split: Which split of the data to load (e.g. `'train'`, `'test'`, `['train',
777
777
'test']`, `'train[80%:]'`,...). See our [split API
778
778
guide](https://www.tensorflow.org/datasets/splits). If `None`, will return
779
- all splits in a `Dict [Split, Sequence]`
779
+ all splits in a `dict [Split, Sequence]`
780
780
data_dir: directory to read/write data. Defaults to the value of the
781
781
environment variable TFDS_DATA_DIR, if set, otherwise falls back to
782
782
'~/tensorflow_datasets'.
@@ -832,11 +832,11 @@ def data_source(
832
832
833
833
834
834
def _get_all_versions (
835
- current_version : version .Version | None ,
836
- extra_versions : Iterable [version .Version ],
835
+ current_version : version_lib .Version | None ,
836
+ extra_versions : Iterable [version_lib .Version ],
837
837
current_version_only : bool ,
838
- ) -> Iterable [str ]:
839
- """Returns the list of all current versions."""
838
+ ) -> set [str ]:
839
+ """Returns the set of all current versions."""
840
840
# Merge current version with all extra versions
841
841
version_list = [current_version ] if current_version else []
842
842
if not current_version_only :
@@ -881,7 +881,7 @@ def _iter_full_names(current_version_only: bool) -> Iterator[str]:
881
881
yield full_name
882
882
883
883
884
- def list_full_names (current_version_only : bool = False ) -> List [str ]:
884
+ def list_full_names (current_version_only : bool = False ) -> list [str ]:
885
885
"""Lists all registered datasets full_names.
886
886
887
887
Args:
@@ -896,7 +896,7 @@ def list_full_names(current_version_only: bool = False) -> List[str]:
896
896
def single_full_names (
897
897
builder_name : str ,
898
898
current_version_only : bool = True ,
899
- ) -> List [str ]:
899
+ ) -> list [str ]:
900
900
"""Returns the list `['ds/c0/v0',...]` or `['ds/v']` for a single builder."""
901
901
return sorted (
902
902
_iter_single_full_names (
0 commit comments