
Commit 379a55d

tomvdw authored and The TensorFlow Datasets Authors committed

Add a method to get a builder config by name and version.

PiperOrigin-RevId: 686429932

1 parent 5e83418 commit 379a55d

File tree

tensorflow_datasets/core/dataset_builder.py
tensorflow_datasets/core/dataset_builder_test.py

2 files changed: +98 -42 lines

tensorflow_datasets/core/dataset_builder.py

Lines changed: 55 additions & 42 deletions
@@ -26,7 +26,7 @@
 import json
 import os
 import sys
-from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Type

 from absl import logging
 from etils import epy
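
Most of this diff is a mechanical typing cleanup: `Optional[X]` and `Union[X, Y]` become PEP 604 unions (`X | None`, `X | Y`, native syntax since Python 3.10), and `typing.Dict`/`List`/`Tuple` become the builtin generics of PEP 585 (Python 3.9+). A minimal sketch of the equivalence; the function names below are illustrative, not from the diff:

from typing import Dict, Optional  # old-style imports, kept only for contrast


def old_style(notes: Optional[Dict[str, str]] = None) -> Optional[str]:
  """Annotations spelled via the `typing` module."""
  return next(iter(notes)) if notes else None


def new_style(notes: dict[str, str] | None = None) -> str | None:
  """Same signature with builtin generics and `X | Y` unions."""
  return next(iter(notes)) if notes else None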
@@ -69,7 +69,7 @@
 ListOrTreeOrElem = type_utils.ListOrTreeOrElem
 Tree = type_utils.Tree
 TreeDict = type_utils.TreeDict
-VersionOrStr = Union[utils.Version, str]
+VersionOrStr = utils.Version | str

 FORCE_REDOWNLOAD = download.GenerateMode.FORCE_REDOWNLOAD
 REUSE_CACHE_IF_EXISTS = download.GenerateMode.REUSE_CACHE_IF_EXISTS
@@ -108,7 +108,7 @@ class BuilderConfig:

   name: str
   version: VersionOrStr | None = None
-  release_notes: Dict[str, str] | None = None
+  release_notes: dict[str, str] | None = None
   supported_versions: list[VersionOrStr] = dataclasses.field(
       default_factory=list
   )
@@ -192,12 +192,12 @@ class DatasetBuilder(registered.RegisteredDataset):
   """

   # Semantic version of the dataset (ex: tfds.core.Version('1.2.0'))
-  VERSION: Optional[utils.Version] = None
+  VERSION: utils.Version | None = None

   # Release notes
   # Metadata only used for documentation. Should be a dict[version,description]
   # Multi-lines are automatically dedent
-  RELEASE_NOTES: ClassVar[Dict[str, str]] = {}
+  RELEASE_NOTES: ClassVar[dict[str, str]] = {}

   # List dataset versions which can be loaded using current code.
   # Data can only be prepared with canonical VERSION or above.
@@ -209,7 +209,7 @@ class DatasetBuilder(registered.RegisteredDataset):
   # Name of the builder config that should be used in case the user doesn't
   # specify a config when loading a dataset. If None, then the first config in
   # `BUILDER_CONFIGS` is used.
-  DEFAULT_BUILDER_CONFIG_NAME: Optional[str] = None
+  DEFAULT_BUILDER_CONFIG_NAME: str | None = None

   # Must be set for datasets that use 'manual_dir' functionality - the ones
   # that require users to do additional steps to download the data
@@ -222,15 +222,15 @@ class DatasetBuilder(registered.RegisteredDataset):

   # Optional max number of simultaneous downloads. Setting this value will
   # override download config settings if necessary.
-  MAX_SIMULTANEOUS_DOWNLOADS: Optional[int] = None
+  MAX_SIMULTANEOUS_DOWNLOADS: int | None = None

   # If not set, pkg_dir_path is inferred. However, if user of class knows better
   # then this can be set directly before init, to avoid heuristic inferences.
   # Example: `imported_builder_cls` function in `registered.py` module sets it.
-  pkg_dir_path: Optional[epath.Path] = None
+  pkg_dir_path: epath.Path | None = None

   # Holds information on versions and configs that should not be used.
-  BLOCKED_VERSIONS: ClassVar[Optional[utils.BlockedVersions]] = None
+  BLOCKED_VERSIONS: ClassVar[utils.BlockedVersions | None] = None

   @classmethod
   def _get_pkg_dir_path(cls) -> epath.Path:
@@ -309,7 +309,7 @@ def __init__(
   @utils.classproperty
   @classmethod
   @utils.memoize()
-  def code_path(cls) -> Optional[epath.Path]:
+  def code_path(cls) -> epath.Path | None:
     """Returns the path to the file where the Dataset class is located.

     Note: As the code can be run inside zip file. The returned value is
@@ -373,7 +373,7 @@ def supported_versions(self):
     return self.SUPPORTED_VERSIONS

   @functools.cached_property
-  def versions(self) -> List[utils.Version]:
+  def versions(self) -> list[utils.Version]:
     """Versions (canonical + availables), in preference order."""
     return [
         utils.Version(v) if isinstance(v, str) else v
@@ -407,7 +407,7 @@ def version(self) -> utils.Version:
     return self._version

   @property
-  def release_notes(self) -> Dict[str, str]:
+  def release_notes(self) -> dict[str, str]:
     if self.builder_config and self.builder_config.release_notes:
       return self.builder_config.release_notes
     else:
@@ -452,7 +452,7 @@ def data_path(self) -> epath.Path:

   @utils.classproperty
   @classmethod
-  def _checksums_path(cls) -> Optional[epath.Path]:
+  def _checksums_path(cls) -> epath.Path | None:
     """Returns the checksums path."""
     # Used:
     #  * To load the checksums (in url_infos)
@@ -476,7 +476,7 @@ def _checksums_path(cls) -> Optional[epath.Path]:
   @utils.classproperty
   @classmethod
   @functools.lru_cache(maxsize=None)
-  def url_infos(cls) -> Optional[Dict[str, download.checksums.UrlInfo]]:
+  def url_infos(cls) -> dict[str, download.checksums.UrlInfo] | None:
     """Load `UrlInfo` from the given path."""
     # Note: If the dataset is downloaded with `record_checksums=True`, urls
     # might be updated but `url_infos` won't as it is memoized.
@@ -516,13 +516,13 @@ def info(self) -> dataset_info.DatasetInfo:

   @utils.classproperty
   @classmethod
-  def default_builder_config(cls) -> Optional[BuilderConfig]:
+  def default_builder_config(cls) -> BuilderConfig | None:
     return _get_default_config(
         builder_configs=cls.BUILDER_CONFIGS,
         default_config_name=cls.DEFAULT_BUILDER_CONFIG_NAME,
     )

-  def get_default_builder_config(self) -> Optional[BuilderConfig]:
+  def get_default_builder_config(self) -> BuilderConfig | None:
     """Returns the default builder config if there is one.

     Note that for dataset builders that cannot use the `cls.BUILDER_CONFIGS`, we
@@ -539,7 +539,7 @@ def get_default_builder_config(self) -> Optional[BuilderConfig]:

   def get_reference(
       self,
-      namespace: Optional[str] = None,
+      namespace: str | None = None,
   ) -> naming.DatasetReference:
     """Returns a reference to the dataset produced by this dataset builder.
@@ -807,9 +807,9 @@ def _update_dataset_info(self) -> None:
   @tfds_logging.as_data_source()
   def as_data_source(
       self,
-      split: Optional[Tree[splits_lib.SplitArg]] = None,
+      split: Tree[splits_lib.SplitArg] | None = None,
       *,
-      decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None,
+      decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
       deserialize_method: decode.DeserializeMethod = decode.DeserializeMethod.DESERIALIZE_AND_DECODE,
   ) -> ListOrTreeOrElem[Sequence[Any]]:
     """Constructs an `ArrayRecordDataSource`.
@@ -818,7 +818,7 @@ def as_data_source(
       split: Which split of the data to load (e.g. `'train'`, `'test'`,
         `['train', 'test']`, `'train[80%:]'`,...). See our [split API
         guide](https://www.tensorflow.org/datasets/splits). If `None`, will
-        return all splits in a `Dict[Split, Sequence]`.
+        return all splits in a `dict[Split, Sequence]`.
       decoders: Nested dict of `Decoder` objects which allow to customize the
         decoding. The structure should match the feature structure, but only
         customized feature keys need to be present. See [the
@@ -913,12 +913,12 @@ def build_single_data_source(split: str) -> Sequence[Any]:
   @tfds_logging.as_dataset()
   def as_dataset(
       self,
-      split: Optional[Tree[splits_lib.SplitArg]] = None,
+      split: Tree[splits_lib.SplitArg] | None = None,
       *,
-      batch_size: Optional[int] = None,
+      batch_size: int | None = None,
       shuffle_files: bool = False,
-      decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None,
-      read_config: Optional[read_config_lib.ReadConfig] = None,
+      decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
+      read_config: read_config_lib.ReadConfig | None = None,
       as_supervised: bool = False,
   ):
     # pylint: disable=line-too-long
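
Only the annotations of `as_dataset` change here; call sites are unaffected. A usage sketch of what these keyword types bind to (the dataset name "mnist" is illustrative, assuming TFDS is installed and the data can be prepared locally):

import tensorflow_datasets as tfds

builder = tfds.builder("mnist")
builder.download_and_prepare()
ds = builder.as_dataset(
    split="train",       # Tree[splits_lib.SplitArg] | None
    batch_size=32,       # int | None
    shuffle_files=True,
    read_config=None,    # read_config_lib.ReadConfig | None
    as_supervised=True,  # yields (image, label) tuples
)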
@@ -1029,9 +1029,9 @@ def as_dataset(
   def _build_single_dataset(
       self,
       split: splits_lib.Split,
-      batch_size: Optional[int],
+      batch_size: int | None,
       shuffle_files: bool,
-      decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]],
+      decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
       read_config: read_config_lib.ReadConfig,
       as_supervised: bool,
   ) -> tf.data.Dataset:
@@ -1064,7 +1064,7 @@ def _build_single_dataset(
           "structure."
       )

-    def lookup_nest(features: Dict[str, Any]) -> Tuple[Any, ...]:
+    def lookup_nest(features: dict[str, Any]) -> tuple[Any, ...]:
      """Converts `features` to the structure described by `supervised_keys`.

      Note that there is currently no way to access features in nested
@@ -1208,7 +1208,7 @@ def _info(self) -> dataset_info.DatasetInfo:
   def _download_and_prepare(
       self,
       dl_manager: download.DownloadManager,
-      download_config: Optional[download.DownloadConfig] = None,
+      download_config: download.DownloadConfig | None = None,
   ) -> None:
     """Downloads and prepares dataset for reading.
@@ -1228,8 +1228,8 @@ def _download_and_prepare(
   def _as_dataset(
       self,
       split: splits_lib.Split,
-      decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]] = None,
-      read_config: Optional[read_config_lib.ReadConfig] = None,
+      decoders: TreeDict[decode.partial_decode.DecoderArg] | None = None,
+      read_config: read_config_lib.ReadConfig | None = None,
       shuffle_files: bool = False,
   ) -> tf.data.Dataset:
     """Constructs a `tf.data.Dataset`.
@@ -1313,7 +1313,7 @@ def _make_download_manager(
   @utils.docs.do_not_doc_in_subclasses
   @utils.classproperty
   @classmethod
-  def builder_config_cls(cls) -> Optional[type[BuilderConfig]]:
+  def builder_config_cls(cls) -> type[BuilderConfig] | None:
     """Returns the builder config class."""
     if not cls.BUILDER_CONFIGS:
       return None
@@ -1328,7 +1328,7 @@ def builder_config_cls(cls) -> Optional[type[BuilderConfig]]:
     return builder_cls

   @property
-  def builder_config(self) -> Optional[Any]:
+  def builder_config(self) -> Any | None:
     """`tfds.core.BuilderConfig` for this builder."""
     return self._builder_config
@@ -1410,6 +1410,19 @@ def builder_configs(cls) -> dict[str, BuilderConfig]:
     )
     return config_dict

+  @classmethod
+  def get_builder_config(
+      cls, name: str, version: str | utils.Version | None = None
+  ) -> BuilderConfig | None:
+    """Returns the builder config with the given name and version."""
+    if version is not None:
+      name_with_version = f"{name}:{version}"
+      if builder_config := cls.builder_configs.get(name_with_version):
+        return builder_config
+    if builder_config := cls.builder_configs.get(name):
+      return builder_config
+    return None
+
   def _get_filename_template(
       self, split_name: str
   ) -> naming.ShardedFileTemplate:
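
The new classmethod resolves a config by trying the versioned key "name:version" in `builder_configs` first, then falling back to the plain name. A minimal sketch of how it could be called; `MyDataset` and its configs are hypothetical and exist only for this example:

import tensorflow_datasets as tfds


class MyDataset(tfds.core.GeneratorBasedBuilder):
  """Toy builder used only to demonstrate the class-level lookup."""

  VERSION = tfds.core.Version("1.0.0")
  BUILDER_CONFIGS = [
      tfds.core.BuilderConfig(name="small", version="1.0.0"),
      tfds.core.BuilderConfig(name="large", version="1.0.0"),
  ]

  def _info(self):
    raise NotImplementedError  # not needed for config lookup

  def _split_generators(self, dl_manager):
    raise NotImplementedError

  def _generate_examples(self):
    raise NotImplementedError


# No data is downloaded or prepared; the lookup is purely class-level.
assert MyDataset.get_builder_config("small").name == "small"
assert MyDataset.get_builder_config("small", version="1.0.0").name == "small"
assert MyDataset.get_builder_config("missing") is None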
@@ -1437,7 +1450,7 @@ class FileReaderBuilder(DatasetBuilder):
   def __init__(
       self,
       *,
-      file_format: Union[None, str, file_adapters.FileFormat] = None,
+      file_format: str | file_adapters.FileFormat | None = None,
       **kwargs: Any,
   ):
     """Initializes an instance of FileReaderBuilder.
@@ -1460,7 +1473,7 @@ def _example_specs(self):
   def _as_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
       self,
       split: splits_lib.Split,
-      decoders: Optional[TreeDict[decode.partial_decode.DecoderArg]],
+      decoders: TreeDict[decode.partial_decode.DecoderArg] | None,
       read_config: read_config_lib.ReadConfig,
       shuffle_files: bool,
   ) -> tf.data.Dataset:
@@ -1508,7 +1521,7 @@ class GeneratorBasedBuilder(FileReaderBuilder):
   def _split_generators(
       self,
       dl_manager: download.DownloadManager,
-  ) -> Dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
+  ) -> dict[splits_lib.Split, split_builder_lib.SplitGenerator]:
     """Downloads the data and returns dataset splits with associated examples.

     Example:
@@ -1743,7 +1756,7 @@ def _download_and_prepare(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
     self.info.set_splits(split_dict)

   def read_text_file(
-      self, filename: epath.PathLike, encoding: Optional[str] = None
+      self, filename: epath.PathLike, encoding: str | None = None
   ) -> str:
     """Returns the text in the given file and records the lineage."""
     filename = epath.Path(filename)
@@ -1775,7 +1788,7 @@ def read_tfrecord_as_dataset(

   def read_tfrecord_as_examples(
       self,
-      filenames: Union[str, Sequence[str]],
+      filenames: str | Sequence[str],
       compression_type: str | None = None,
       num_parallel_reads: int | None = None,
   ) -> Iterator[tf.train.Example]:
@@ -1932,9 +1945,9 @@ def _check_split_names(split_names: Iterable[str]) -> None:


 def _get_default_config(
-    builder_configs: List[BuilderConfig],
-    default_config_name: Optional[str],
-) -> Optional[BuilderConfig]:
+    builder_configs: list[BuilderConfig],
+    default_config_name: str | None,
+) -> BuilderConfig | None:
   """Returns the default config from the given builder configs.

   Arguments:
@@ -1995,8 +2008,8 @@ def load_default_config_name(builder_dir: epath.Path) -> str | None:


 def canonical_version_for_config(
-    instance_or_cls: Union[DatasetBuilder, Type[DatasetBuilder]],
-    config: Optional[BuilderConfig] = None,
+    instance_or_cls: DatasetBuilder | Type[DatasetBuilder],
+    config: BuilderConfig | None = None,
 ) -> utils.Version:
   """Get the canonical version for the given config.
tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 43 additions & 0 deletions
@@ -453,6 +453,49 @@ def test_builder_configs_configs_with_multiple_versions(self):
         set(DummyDatasetWithVersionedConfigs.builder_configs.keys()),
     )

+  def test_get_builder_config(self):
+    plus1 = DummyDatasetWithConfigs.get_builder_config("plus1")
+    self.assertEqual(plus1.name, "plus1")
+    plus2 = DummyDatasetWithConfigs.get_builder_config("plus2")
+    self.assertEqual(plus2.name, "plus2")
+
+    plus1_001 = DummyDatasetWithConfigs.get_builder_config(
+        "plus1", version="0.0.1"
+    )
+    self.assertEqual(plus1_001.name, "plus1")
+    self.assertEqual(str(plus1_001.version), "0.0.1")
+
+    plus2_002 = DummyDatasetWithConfigs.get_builder_config(
+        "plus2", version="0.0.2"
+    )
+    self.assertEqual(plus2_002.name, "plus2")
+    self.assertEqual(str(plus2_002.version), "0.0.2")
+
+    self.assertIsNone(
+        DummyDatasetWithConfigs.get_builder_config(
+            "i_dont_exist", version="0.0.1"
+        )
+    )
+
+    # DummyDatasetWithVersionedConfigs
+    cfg1_001 = DummyDatasetWithVersionedConfigs.get_builder_config(
+        "cfg1", version="0.0.1"
+    )
+    self.assertEqual(cfg1_001.name, "cfg1")
+    self.assertEqual(str(cfg1_001.version), "0.0.1")
+
+    cfg1_002 = DummyDatasetWithVersionedConfigs.get_builder_config(
+        "cfg1", version="0.0.2"
+    )
+    self.assertEqual(cfg1_002.name, "cfg1")
+    self.assertEqual(str(cfg1_002.version), "0.0.2")
+
+    self.assertIsNone(
+        DummyDatasetWithVersionedConfigs.get_builder_config(
+            "cfg1", version="0.0.3"
+        )
+    )
+
   def test_is_blocked(self):
     with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
       tmp_dir = epath.Path(tmp_dir)
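
The versioned fixture exercises the case where two configs share a name and differ only in version. A minimal sketch of the lookup semantics the test relies on, using a plain dict to mimic `builder_configs`; the key layout is an assumption inferred from the test's expectations, not code from this diff:

builder_configs: dict[str, str] = {
    "cfg1:0.0.1": "cfg1 at 0.0.1",
    "cfg1:0.0.2": "cfg1 at 0.0.2",
    # No plain "cfg1" key: the name alone is ambiguous across versions.
}


def get(name: str, version: str | None = None) -> str | None:
  """Mirrors get_builder_config: versioned key first, then plain name."""
  if version is not None and (hit := builder_configs.get(f"{name}:{version}")):
    return hit
  return builder_configs.get(name)


assert get("cfg1", version="0.0.1") == "cfg1 at 0.0.1"
assert get("cfg1", version="0.0.3") is None  # no "cfg1:0.0.3", no plain "cfg1"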
