Use simple_parsing for build cli command. #11103

Merged · 1 commit · Aug 11, 2025
tensorflow_datasets/scripts/cli/build.py (249 changes: 125 additions & 124 deletions)
@@ -17,121 +17,121 @@

 import argparse
 from collections.abc import Iterator
+import dataclasses
 import functools
 import importlib
 import itertools
 import json
 import multiprocessing
 import os
+import typing
 from typing import Any, Type

 from absl import logging
+import simple_parsing
 import tensorflow_datasets as tfds
 from tensorflow_datasets.scripts.cli import cli_utils

 # pylint: disable=logging-fstring-interpolation


-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command.
-
-  New flags should be added to `cli_utils` module.
-
-  Args:
-    parsers: The subparsers object to add the parser to.
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class Args:
+  """CLI arguments for building datasets.
+
+  Attributes:
+    positional_datasets: Name(s) of the dataset(s) to build. Default to current
+      dir. See https://www.tensorflow.org/datasets/cli for accepted values.
+    datasets: Datasets can also be provided as keyword argument.
+    debug: Debug & tests options. Use --pdb to enter post-mortem debugging mode
+      if an exception is raised.
+    paths: Path options.
+    generation: Generation options.
+    publishing: Publishing options.
+    automation: Automation options.
   """
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
+
+  positional_datasets: list[str] = simple_parsing.field(
+      positional=True,
       nargs='*',
       help=(
           'Name(s) of the dataset(s) to build. Default to current dir. '
           'See https://www.tensorflow.org/datasets/cli for accepted values.'
       ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
+      default_factory=list,
+      # Need to explicitly set metavar for command-line help.
+      metavar='datasets',
   )
+  datasets: list[str] = simple_parsing.field(nargs='*', default_factory=list)

-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
-
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
+  debug: cli_utils.DebugOptions = cli_utils.DebugOptions()
+  paths: cli_utils.PathOptions = simple_parsing.field(
+      default_factory=cli_utils.PathOptions
   )
+  generation: cli_utils.GenerationOptions = simple_parsing.field(
+      default_factory=cli_utils.GenerationOptions
+  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  publishing: cli_utils.PublishingOptions = simple_parsing.field(
+      default_factory=cli_utils.PublishingOptions
   )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  automation: cli_utils.AutomationOptions = simple_parsing.field(
+      default_factory=cli_utils.AutomationOptions
   )

-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self) -> None:
+    """Build the given datasets."""
+    # Eventually register additional datasets imports
+    if self.generation.imports:
+      list(
+          importlib.import_module(m) for m in self.generation.imports.split(',')
+      )
+
+    # Select datasets to generate
+    datasets = self.positional_datasets + self.datasets
+    if (
+        self.automation.exclude_datasets
+    ):  # Generate all datasets if `--exclude_datasets` set
+      if datasets:
+        raise ValueError("--exclude_datasets can't be used with `datasets`")
+      datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
+          self.automation.exclude_datasets.split(',')
+      )
+      datasets = sorted(datasets)  # `set` is not deterministic
+    else:
+      datasets = datasets or ['']  # Empty string for default
+
+    # Import builder classes
+    builders_cls_and_kwargs = [
+        _get_builder_cls_and_kwargs(
+            dataset, has_imports=bool(self.generation.imports)
+        )
+        for dataset in datasets
+    ]
+
+    # Parallelize datasets generation.
+    builders = itertools.chain(*(
+        _make_builders(self, builder_cls, builder_kwargs)
+        for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
+    ))
+    process_builder_fn = functools.partial(
+        _download if self.generation.download_only else _download_and_prepare,
+        self,
+    )
+
-def _build_datasets(args: argparse.Namespace) -> None:
-  """Build the given datasets."""
-  # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+    if self.generation.num_processes == 1:
+      for builder in builders:
+        process_builder_fn(builder)
+    else:
+      with multiprocessing.Pool(self.generation.num_processes) as pool:
+        pool.map(process_builder_fn, builders)
-
-  # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
-    if datasets:
-      raise ValueError("--exclude_datasets can't be used with `datasets`")
-    datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
-    )
-    datasets = sorted(datasets)  # `set` is not deterministic
-  else:
-    datasets = datasets or ['']  # Empty string for default
-
-  # Import builder classes
-  builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
-      for dataset in datasets
-  ]
-
-  # Parallelize datasets generation.
-  builders = itertools.chain(*(
-      _make_builders(args, builder_cls, builder_kwargs)
-      for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
-  ))
-  process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
-  )
-
-  if args.num_processes == 1:
-    for builder in builders:
-      process_builder_fn(builder)
-  else:
-    with multiprocessing.Pool(args.num_processes) as pool:
-      pool.map(process_builder_fn, builders)
+def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
+  """Add subparser for `build` command."""
+  parser = parsers.add_parser(
+      'build', help='Commands for downloading and preparing datasets.'
+  )
+  parser = typing.cast(simple_parsing.ArgumentParser, parser)
+  parser.add_arguments(Args, dest='args')
+  parser.set_defaults(subparser_fn=lambda args: args.args.execute())


 def _make_builders(
-    args: argparse.Namespace,
+    args: Args,
     builder_cls: Type[tfds.core.DatasetBuilder],
     builder_kwargs: dict[str, Any],
 ) -> Iterator[tfds.core.DatasetBuilder]:
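Note: the core pattern in this hunk is that simple_parsing turns a dataclass
into an argparse parser. A minimal, self-contained sketch of that pattern
(DemoArgs, its fields, and the sample argv are illustrative, not from this
PR):

import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True, kw_only=True)
class DemoArgs:
  """Example CLI arguments."""

  # Declared positional, like `positional_datasets` above.
  names: list[str] = simple_parsing.field(
      positional=True, nargs='*', default_factory=list
  )
  # Plain fields become keyword flags, e.g. `--num_processes 2`.
  num_processes: int = 1


parser = simple_parsing.ArgumentParser()
parser.add_arguments(DemoArgs, dest='args')
namespace = parser.parse_args(['a', 'b', '--num_processes', '2'])
demo: DemoArgs = namespace.args  # The populated, frozen dataclass.
print(demo.names, demo.num_processes)  # ['a', 'b'] 2

This is also why `subparser_fn=lambda args: args.args.execute()` works: the
parsed namespace carries the dataclass under the `dest` name, and the command's
behavior lives on the dataclass itself.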
@@ -146,7 +146,7 @@ def _make_builders(
     Initialized dataset builders.
   """
   # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
     if 'version' in builder_kwargs:
       raise ValueError(
           "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -157,19 +157,19 @@ def _make_builders(
   builder_kwargs['config'] = _get_config_name(
       builder_cls=builder_cls,
       config_kwarg=builder_kwargs.get('config'),
-      config_name=args.config,
-      config_idx=args.config_idx,
+      config_name=args.generation.config,
+      config_idx=args.generation.config_idx,
   )

-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format

   make_builder = functools.partial(
       _make_builder,
       builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.paths.data_dir,
       **builder_kwargs,
   )

Expand Down Expand Up @@ -203,7 +203,7 @@ def _get_builder_cls_and_kwargs(
if not has_imports:
path = _search_script_path(ds_to_build)
if path is not None:
logging.info(f'Loading dataset {ds_to_build} from path: {path}')
logging.info('Loading dataset %s from path: %s', ds_to_build, path)
# Dynamically load user dataset script
# When possible, load from the parent's parent, so module is named
# "foo.foo_dataset_builder".
@@ -228,7 +228,9 @@ def _get_builder_cls_and_kwargs(
   name, builder_kwargs = tfds.core.naming.parse_builder_name_kwargs(ds_to_build)
   builder_cls = tfds.builder_cls(str(name))
   logging.info(
-      f'Loading dataset {ds_to_build} from imports: {builder_cls.__module__}'
+      'Loading dataset %s from imports: %s',
+      ds_to_build,
+      builder_cls.__module__,
   )
   return builder_cls, builder_kwargs

@@ -308,7 +310,7 @@ def _make_builder(


 def _download(
-    args: argparse.Namespace,
+    args: Args,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Downloads all files of the given builder."""
@@ -330,7 +332,7 @@ def _download(
   if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
     max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS

-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.paths.download_dir or os.path.join(
       builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
   )
   dl_manager = tfds.download.DownloadManager(
@@ -352,55 +354,55 @@


 def _download_and_prepare(
-    args: argparse.Namespace,
+    args: Args,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Generate a single builder."""
   cli_utils.download_and_prepare(
       builder=builder,
       download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
-      beam_pipeline_options=args.beam_pipeline_options,
-      nondeterministic_order=args.nondeterministic_order,
+      download_dir=args.paths.download_dir,
+      publish_dir=args.publishing.publish_dir,
+      skip_if_published=args.publishing.skip_if_published,
+      overwrite=args.debug.overwrite,
+      beam_pipeline_options=args.generation.beam_pipeline_options,
+      nondeterministic_order=args.generation.nondeterministic_order,
   )


 def _make_download_config(
-    args: argparse.Namespace,
+    args: Args,
     dataset_name: str,
 ) -> tfds.download.DownloadConfig:
   """Generate the download and prepare configuration."""
   # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.paths.manual_dir
+  if args.paths.add_name_to_manual_dir:
     manual_dir = manual_dir / dataset_name

   kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.num_shards:
-    kwargs['num_shards'] = args.num_shards
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.num_shards:
+    kwargs['num_shards'] = args.generation.num_shards
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))

   if 'download_mode' in kwargs:
     kwargs['download_mode'] = tfds.download.GenerateMode(
         kwargs['download_mode']
     )
   else:
     kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
     kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO

   return tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.paths.extract_dir,
       manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
       **kwargs,
   )
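Note: `max_shard_size_mb << 20` converts mebibytes to bytes, since
1 MiB = 2**20 bytes:

assert 64 << 20 == 64 * 2**20 == 67_108_864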

@@ -445,11 +447,10 @@ def _get_config_name(
     else:
       return config_name
   elif config_idx is not None:  # `--config_idx 123`
-    if config_idx > len(builder_cls.BUILDER_CONFIGS):
+    if config_idx >= len(builder_cls.BUILDER_CONFIGS):
       raise ValueError(
-          f'--config_idx {config_idx} greater than number '
-          f'of configs {len(builder_cls.BUILDER_CONFIGS)} for '
-          f'{builder_cls.name}.'
+          f'--config_idx {config_idx} greater than number of configs '
+          f'{len(builder_cls.BUILDER_CONFIGS)} for {builder_cls.name}.'
       )
     else:
       # Use `config.name` to avoid
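Note: the `>` to `>=` change fixes an off-by-one error: `config_idx` is a
0-based index into `BUILDER_CONFIGS`, so `len(BUILDER_CONFIGS)` itself is
already out of range. A small sketch:

configs = ['c0', 'c1', 'c2']  # len(configs) == 3; valid indices are 0..2
config_idx = 3
if config_idx >= len(configs):  # with `>`, 3 would slip through to an IndexError
  raise ValueError(f'--config_idx {config_idx} out of range')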
tensorflow_datasets/scripts/cli/build_test.py (3 changes: 2 additions & 1 deletion)
@@ -311,7 +311,8 @@ def test_download_only(build):
 )
 def test_make_download_config(args: str, download_config_kwargs):
   args = main._parse_flags(f'tfds build x {args}'.split())
-  actual = build_lib._make_download_config(args, dataset_name='x')
+  cmd_args: build_lib.Args = args.args
+  actual = build_lib._make_download_config(cmd_args, dataset_name='x')
   # Ignore the beam runner
   actual = actual.replace(beam_runner=None)
   expected = tfds.download.DownloadConfig(**download_config_kwargs)
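Note: `args.args` follows from `parser.add_arguments(Args, dest='args')` in
build.py: simple_parsing stores the populated dataclass on the argparse
namespace under the given `dest`. Roughly (assuming main._parse_flags returns
the parsed namespace, as its use above suggests):

ns = main._parse_flags('tfds build x'.split())
cmd_args: build_lib.Args = ns.args  # The frozen Args dataclass instance.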