Skip to content

Commit 839e49e

Browse files
tomvdw authored and The TensorFlow Datasets Authors committed
If multiple data dirs are registered, scan them concurrently to determine whether they contain the dataset
PiperOrigin-RevId: 648360247
1 parent d9c0110 commit 839e49e

File tree

2 files changed

+30
-34
lines changed

2 files changed

+30
-34
lines changed

tensorflow_datasets/core/dataset_builder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ class BuilderConfig:
119119
def from_dataset_info(
120120
cls,
121121
info_proto: dataset_info_pb2.DatasetInfo,
122-
) -> Optional["BuilderConfig"]:
122+
) -> BuilderConfig | None:
123123
"""Instantiates a BuilderConfig from the given proto.
124124
125125
Args:

tensorflow_datasets/core/read_only_builder.py

Lines changed: 29 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,11 @@
1717

1818
from __future__ import annotations
1919

20+
from collections.abc import Sequence
2021
import functools
2122
import os
2223
import typing
23-
from typing import Any, List, Optional, Type
24+
from typing import Any, Type
2425

2526
from etils import epy
2627
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
@@ -54,7 +55,7 @@ def __init__(
5455
self,
5556
builder_dir: epath.PathLike,
5657
*,
57-
info_proto: Optional[dataset_info_pb2.DatasetInfo] = None,
58+
info_proto: dataset_info_pb2.DatasetInfo | None = None,
5859
):
5960
"""Constructor.
6061
@@ -82,9 +83,7 @@ def __init__(
8283
# original source code.
8384
self.__module__ = info_proto.module_name
8485

85-
builder_config = dataset_builder.BuilderConfig.from_dataset_info(
86-
info_proto,
87-
)
86+
builder_config = dataset_builder.BuilderConfig.from_dataset_info(info_proto)
8887
# __init__ will call _build_data_dir, _create_builder_config,
8988
# _pick_version to set the data_dir, config, and version
9089
super().__init__(
@@ -158,9 +157,9 @@ def builder_from_directory(
158157

159158

160159
def builder_from_directories(
161-
builder_dirs: List[epath.PathLike],
160+
builder_dirs: Sequence[epath.PathLike],
162161
*,
163-
filetype_suffix: Optional[str] = None, # DEPRECATED
162+
filetype_suffix: str | None = None, # DEPRECATED
164163
) -> dataset_builder.DatasetBuilder:
165164
"""Loads a `tfds.core.DatasetBuilder` from the given generated dataset path.
166165
@@ -244,10 +243,8 @@ def builder_from_metadata(
244243
Returns:
245244
builder: `tfds.core.DatasetBuilder`, builder for dataset at the given path.
246245
"""
247-
return ReadOnlyBuilder(
248-
builder_dir=builder_dir,
249-
info_proto=info_proto,
250-
)
246+
builder = ReadOnlyBuilder(builder_dir=builder_dir, info_proto=info_proto)
247+
return builder
251248

252249

253250
@error_utils.reraise_with_context(registered.DatasetNotFoundError)
@@ -292,7 +289,7 @@ def builder_from_files(
292289
return builder_from_directory(builder_dir)
293290

294291

295-
def _find_builder_dir(name: str, **builder_kwargs: Any) -> Optional[str]:
292+
def _find_builder_dir(name: str, **builder_kwargs: Any) -> str | None:
296293
"""Search whether the given dataset is present on disk and return its path.
297294
298295
Note:
@@ -344,17 +341,17 @@ def _find_builder_dir(name: str, **builder_kwargs: Any) -> Optional[str]:
344341
return None
345342

346343
# Search the dataset across all registered data_dirs
347-
all_builder_dirs = []
348-
all_data_dirs = file_utils.list_data_dirs(given_data_dir=data_dir)
344+
all_builder_dirs = set()
345+
all_data_dirs = set(file_utils.list_data_dirs(given_data_dir=data_dir))
346+
find_builder_fn = functools.partial(
347+
_find_builder_dir_single_dir,
348+
builder_name=name.name,
349+
version_str=str(version) if version else None,
350+
config_name=config,
351+
)
349352
for current_data_dir in all_data_dirs:
350-
builder_dir = _find_builder_dir_single_dir(
351-
name.name,
352-
data_dir=current_data_dir,
353-
version_str=str(version) if version else None,
354-
config_name=config,
355-
)
356-
if builder_dir:
357-
all_builder_dirs.append(builder_dir)
353+
if builder_dir := find_builder_fn(data_dir=current_data_dir):
354+
all_builder_dirs.add(builder_dir)
358355

359356
if not all_builder_dirs:
360357
all_dirs_str = '\n\t- '.join([''] + [str(dir) for dir in all_data_dirs])
@@ -378,14 +375,14 @@ def _find_builder_dir(name: str, **builder_kwargs: Any) -> Optional[str]:
378375
'Please resolve the ambiguity by explicitly setting `data_dir=`.'
379376
)
380377

381-
return all_builder_dirs[0]
378+
return all_builder_dirs.pop()
382379

383380

384381
def _get_dataset_dir(
385382
builder_dir: epath.Path,
386383
*,
387384
version_str: str,
388-
config_name: Optional[str] = None,
385+
config_name: str | None = None,
389386
) -> epath.Path:
390387
"""Returns the path for the given dataset, config and version."""
391388
dataset_dir = builder_dir
@@ -402,12 +399,11 @@ def _contains_dataset(dataset_dir: epath.PathLike) -> bool:
402399

403400

404401
def _find_builder_dir_single_dir(
405-
builder_name: str,
406-
*,
407402
data_dir: epath.PathLike,
408-
config_name: Optional[str] = None,
409-
version_str: Optional[str] = None,
410-
) -> Optional[str]:
403+
builder_name: str,
404+
config_name: str | None = None,
405+
version_str: str | None = None,
406+
) -> str | None:
411407
"""Same as `find_builder_dir` but requires explicit dir."""
412408

413409
builder_dir = epath.Path(data_dir) / builder_name
@@ -462,7 +458,7 @@ def _find_builder_dir_single_dir(
462458
def _get_default_config_name(
463459
builder_dir: epath.Path,
464460
name: str,
465-
) -> Optional[str]:
461+
) -> str | None:
466462
"""Returns the default config of the given dataset, None if not found."""
467463
# Search for the DatasetBuilder generation code
468464
try:
@@ -488,9 +484,9 @@ def _get_default_config_name(
488484
def _get_version_str(
489485
builder_dir: epath.Path,
490486
*,
491-
config_name: Optional[str] = None,
492-
requested_version: Optional[str] = None,
493-
) -> Optional[str]:
487+
config_name: str | None = None,
488+
requested_version: str | None = None,
489+
) -> str | None:
494490
"""Returns the version name found in the directory.
495491
496492
Args:

0 commit comments

Comments
 (0)