Skip to content

Commit e1220a9

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Move get_dataset_dir from read_only_builder.py to file_utils.py:
- Do not use recursion in `_find_builder_dir_single_dir`. - Fix types in `read_only_builder.py`. - Use `constants` in tests. PiperOrigin-RevId: 683116695
1 parent 5e7d183 commit e1220a9

File tree

5 files changed

+292
-248
lines changed

5 files changed

+292
-248
lines changed

tensorflow_datasets/core/read_only_builder.py

Lines changed: 66 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def builder_from_directory(
169169
above.
170170
171171
Args:
172-
builder_dir: `str`, path of the directory containing the dataset to read (
173-
e.g. `~/tensorflow_datasets/mnist/3.0.0/`).
172+
builder_dir: Path of the directory containing the dataset to read ( e.g.
173+
`~/tensorflow_datasets/mnist/3.0.0/`).
174174
175175
Returns:
176176
builder: `tfds.core.DatasetBuilder`, builder for dataset at the given path.
@@ -311,7 +311,7 @@ def builder_from_files(
311311
return builder_from_directory(builder_dir)
312312

313313

314-
def _find_builder_dir(name: str, **builder_kwargs: Any) -> str | None:
314+
def _find_builder_dir(name: str, **builder_kwargs: Any) -> epath.Path | None:
315315
"""Search whether the given dataset is present on disk and return its path.
316316
317317
Note:
@@ -363,13 +363,13 @@ def _find_builder_dir(name: str, **builder_kwargs: Any) -> str | None:
363363
return None
364364

365365
# Search the dataset across all registered data_dirs
366-
all_builder_dirs = set()
366+
all_builder_dirs: set[epath.Path] = set()
367367
all_data_dirs = set(file_utils.list_data_dirs(given_data_dir=data_dir))
368368
find_builder_fn = functools.partial(
369369
_find_builder_dir_single_dir,
370370
builder_name=name.name,
371-
version_str=str(version) if version else None,
372371
config_name=config,
372+
version=version,
373373
)
374374
if len(all_data_dirs) <= 1:
375375
for current_data_dir in all_data_dirs:
@@ -398,14 +398,14 @@ def wrapped_find_builder_fn(data_dir):
398398

399399
# If the dataset root_dir exists, a common error is that the config name
400400
# was not specified. So we list the possible configs and display them.
401-
possible_configs = _list_possible_configs(name, all_data_dirs)
401+
possible_configs = _list_possible_configs(name.name, all_data_dirs)
402402
if possible_configs:
403-
configs = '\n\t- '.join([''] + list(possible_configs))
403+
configs_str = '\n\t- '.join([''] + possible_configs)
404404
error_msg = (
405405
f'However, a folder for "{name.name}" does exist. Is it possible that'
406406
' you specified the wrong config? You can add a config by replacing'
407407
f' `tfds.load({name.name})` by `tfds.load("{name.name}/my_config")`.'
408-
f' Possible configs are:{configs}\n'
408+
f' Possible configs are:{configs_str}\n'
409409
)
410410

411411
error_utils.add_context(error_msg)
@@ -431,32 +431,19 @@ def wrapped_find_builder_fn(data_dir):
431431

432432

433433
def _list_possible_configs(
434-
name: naming.DatasetName, all_data_dirs: set[epath.PathLike]
435-
) -> Sequence[str]:
434+
builder_name: str, all_data_dirs: set[epath.PathLike]
435+
) -> list[str]:
436436
configs = []
437437
for data_dir in all_data_dirs:
438-
root_dir = epath.Path(data_dir) / name.name
439-
if root_dir.exists():
440-
for path in root_dir.iterdir():
438+
builder_dir = epath.Path(data_dir) / builder_name
439+
if builder_dir.exists():
440+
for path in builder_dir.iterdir():
441441
if path.is_dir():
442442
configs.append(path.name)
443443
return configs
444444

445445

446-
def _get_dataset_dir(
447-
builder_dir: epath.Path,
448-
*,
449-
version_str: str,
450-
config_name: str | None = None,
451-
) -> epath.Path:
452-
"""Returns the path for the given dataset, config and version."""
453-
dataset_dir = builder_dir
454-
if config_name:
455-
dataset_dir = dataset_dir / config_name
456-
return dataset_dir / version_str
457-
458-
459-
def _contains_dataset(dataset_dir: epath.PathLike) -> bool:
446+
def _contains_dataset(dataset_dir: epath.Path) -> bool:
460447
try:
461448
return feature_lib.make_config_path(dataset_dir).exists()
462449
except (OSError, tf.errors.PermissionDeniedError):
@@ -467,49 +454,53 @@ def _find_builder_dir_single_dir(
467454
data_dir: epath.PathLike,
468455
builder_name: str,
469456
config_name: str | None = None,
470-
version_str: str | None = None,
471-
) -> str | None:
457+
version: version_lib.Version | str | None = None,
458+
) -> epath.Path | None:
472459
"""Same as `find_builder_dir` but requires explicit dir."""
473460

474-
builder_dir = epath.Path(data_dir) / builder_name
475-
476461
# If the version is specified, check if the dataset dir exists and return.
477-
if version_str and version_lib.Version.is_valid(version_str):
478-
dataset_dir = _get_dataset_dir(
479-
builder_dir=builder_dir,
480-
version_str=version_str,
462+
if version_lib.Version.is_valid(version):
463+
dataset_dir = file_utils.get_dataset_dir(
464+
data_dir=data_dir,
465+
builder_name=builder_name,
481466
config_name=config_name,
467+
version=version,
482468
)
483469
if _contains_dataset(dataset_dir):
484-
return os.fspath(dataset_dir)
470+
return dataset_dir
485471

486472
# If no config_name or an empty string was given, we try to find the default
487473
# config and load the dataset for that.
488474
if not config_name:
489-
default_config_name = _get_default_config_name(
490-
builder_dir=builder_dir, name=builder_name
475+
config_name = _get_default_config_name(
476+
data_dir=data_dir, builder_name=builder_name
491477
)
492-
if default_config_name:
493-
return _find_builder_dir_single_dir(
494-
builder_name=builder_name,
478+
if version_lib.Version.is_valid(version):
479+
dataset_dir = file_utils.get_dataset_dir(
495480
data_dir=data_dir,
496-
config_name=default_config_name,
497-
version_str=version_str,
481+
builder_name=builder_name,
482+
config_name=config_name,
483+
version=version,
498484
)
485+
if _contains_dataset(dataset_dir):
486+
return dataset_dir
499487

500488
# Dataset wasn't found, try to find a suitable available version.
501-
found_version_str = _get_version_str(
502-
builder_dir, config_name=config_name, requested_version=version_str
489+
found_version = _get_version(
490+
data_dir=data_dir,
491+
builder_name=builder_name,
492+
config_name=config_name,
493+
requested_version=version,
503494
)
504-
if found_version_str and (
505-
version_str is None or found_version_str != version_str
506-
):
507-
return _find_builder_dir_single_dir(
508-
builder_name=builder_name,
495+
if found_version and str(found_version) != version:
496+
dataset_dir = file_utils.get_dataset_dir(
509497
data_dir=data_dir,
498+
builder_name=builder_name,
510499
config_name=config_name,
511-
version_str=found_version_str,
500+
version=found_version,
512501
)
502+
if _contains_dataset(dataset_dir):
503+
return dataset_dir
513504

514505
# If no builder found, we populate the error_context with useful information
515506
# and return None.
@@ -521,16 +512,19 @@ def _find_builder_dir_single_dir(
521512

522513

523514
def _get_default_config_name(
524-
builder_dir: epath.Path,
525-
name: str,
515+
data_dir: epath.Path,
516+
builder_name: str,
526517
) -> str | None:
527518
"""Returns the default config of the given dataset, None if not found."""
519+
builder_dir = file_utils.get_dataset_dir(
520+
data_dir=data_dir, builder_name=builder_name
521+
)
528522
# Search for the DatasetBuilder generation code
529523
try:
530524
# Warning: The registered dataset may not match the files (e.g. if
531525
# the imported datasets has the same name as the generated files while
532526
# being 2 differents datasets)
533-
cls = registered.imported_builder_cls(name)
527+
cls = registered.imported_builder_cls(builder_name)
534528
cls = typing.cast(Type[dataset_builder.DatasetBuilder], cls)
535529
except registered.DatasetNotFoundError:
536530
pass
@@ -543,45 +537,43 @@ def _get_default_config_name(
543537
return cls.default_builder_config.name
544538

545539
# Otherwise, try to load default config from common metadata
546-
return dataset_builder.load_default_config_name(epath.Path(builder_dir))
540+
return dataset_builder.load_default_config_name(builder_dir)
547541

548542

549-
def _get_version_str(
550-
builder_dir: epath.Path,
551-
*,
543+
def _get_version(
544+
data_dir: epath.Path,
545+
builder_name: str,
552546
config_name: str | None = None,
553-
requested_version: str | None = None,
554-
) -> str | None:
555-
"""Returns the version name found in the directory.
547+
requested_version: version_lib.Version | str | None = None,
548+
) -> version_lib.Version | None:
549+
"""Returns the version name found in the builder directory.
556550
557551
Args:
558-
builder_dir: Directory containing the versions (`builder_dir/1.0.0/`,...)
559-
config_name: Optional name of the config that should be used. Will be
560-
ignored if it is an empty string.
552+
data_dir: Directory containing the builder.
553+
builder_name: Name of the builder.
554+
config_name: Name of the config.
561555
requested_version: Optional version to search (e.g. `1.0.0`, `2.*.*`,...)
562-
563-
Returns:
564-
version_str: The version directory name found in `builder_dir`.
565556
"""
566-
if config_name:
567-
builder_dir = builder_dir / config_name
568-
all_versions = version_lib.list_all_versions(os.fspath(builder_dir))
557+
config_dir = file_utils.get_dataset_dir(
558+
data_dir=data_dir, builder_name=builder_name, config_name=config_name
559+
)
560+
all_versions = version_lib.list_all_versions(config_dir)
569561
# Version not given, using the latest one.
570562
if not requested_version and all_versions:
571-
return str(all_versions[-1])
563+
return all_versions[-1]
572564
# Version given, return the highest version matching `requested_version`.
573565
for v in reversed(all_versions):
574566
if v.match(requested_version):
575-
return str(v)
567+
return v
576568
# Directory doesn't have version, or requested_version doesn't match
577569
if requested_version:
578570
error_msg = (
579571
f'No version matching the requested {requested_version} was '
580-
f'found in the builder directory: {builder_dir}.'
572+
f'found in the builder directory: {config_dir}.'
581573
)
582574
else:
583575
error_msg = (
584-
f"The builder directory {builder_dir} doesn't contain any versions."
576+
f"The builder directory {config_dir} doesn't contain any versions."
585577
)
586578
error_utils.add_context(error_msg)
587579
return None

0 commit comments

Comments
 (0)