diff --git a/tensorflow_datasets/scripts/cli/build.py b/tensorflow_datasets/scripts/cli/build.py index 062d2f0b89d..36e21f3d469 100644 --- a/tensorflow_datasets/scripts/cli/build.py +++ b/tensorflow_datasets/scripts/cli/build.py @@ -17,121 +17,121 @@ import argparse from collections.abc import Iterator +import dataclasses import functools import importlib import itertools import json import multiprocessing import os +import typing from typing import Any, Type from absl import logging +import simple_parsing import tensorflow_datasets as tfds from tensorflow_datasets.scripts.cli import cli_utils -# pylint: disable=logging-fstring-interpolation - -def register_subparser(parsers: argparse._SubParsersAction) -> None: # pylint: disable=protected-access - """Add subparser for `build` command. - - New flags should be added to `cli_utils` module. - - Args: - parsers: The subparsers object to add the parser to. +@dataclasses.dataclass(frozen=True, kw_only=True) +class Args: + """CLI arguments for building datasets. + + Attributes: + positional_datasets: Name(s) of the dataset(s) to build. Default to current + dir. See https://www.tensorflow.org/datasets/cli for accepted values. + datasets: Datasets can also be provided as keyword argument. + debug: Debug & tests options. Use --pdb to enter post-mortem debugging mode + if an exception is raised. + paths: Path options. + generation: Generation options. + publishing: Publishing options. + automation: Automation options. """ - build_parser = parsers.add_parser( - 'build', help='Commands for downloading and preparing datasets.' - ) - build_parser.add_argument( - 'datasets', # Positional arguments - type=str, + + positional_datasets: list[str] = simple_parsing.field( + positional=True, nargs='*', - help=( - 'Name(s) of the dataset(s) to build. Default to current dir. ' - 'See https://www.tensorflow.org/datasets/cli for accepted values.' - ), - ) - build_parser.add_argument( # Also accept keyword arguments - '--datasets', - type=str, - nargs='+', - dest='datasets_keyword', - help='Datasets can also be provided as keyword argument.', + default_factory=list, + # Need to explicitly set metavar for command-line help. + metavar='datasets', ) + datasets: list[str] = simple_parsing.field(nargs='*', default_factory=list) - cli_utils.add_debug_argument_group(build_parser) - cli_utils.add_path_argument_group(build_parser) - cli_utils.add_generation_argument_group(build_parser) - cli_utils.add_publish_argument_group(build_parser) - - # **** Automation options **** - automation_group = build_parser.add_argument_group( - 'Automation', description='Used by automated scripts.' + debug: cli_utils.DebugOptions = cli_utils.DebugOptions() + paths: cli_utils.PathOptions = simple_parsing.field( + default_factory=cli_utils.PathOptions + ) + generation: cli_utils.GenerationOptions = simple_parsing.field( + default_factory=cli_utils.GenerationOptions ) - automation_group.add_argument( - '--exclude_datasets', - type=str, - help=( - 'If set, generate all datasets except the one defined here. ' - 'Comma separated list of datasets to exclude. ' - ), + publishing: cli_utils.PublishingOptions = simple_parsing.field( + default_factory=cli_utils.PublishingOptions ) - automation_group.add_argument( - '--experimental_latest_version', - action='store_true', - help=( - 'Build the latest Version(experiments=...) available rather than ' - 'default version.' 
- ), + automation: cli_utils.AutomationOptions = simple_parsing.field( + default_factory=cli_utils.AutomationOptions ) - build_parser.set_defaults(subparser_fn=_build_datasets) + def execute(self) -> None: + """Build the given datasets.""" + # Eventually register additional datasets imports + if self.generation.imports: + list( + importlib.import_module(m) for m in self.generation.imports.split(',') + ) + # Select datasets to generate + datasets = self.positional_datasets + self.datasets + if ( + self.automation.exclude_datasets + ): # Generate all datasets if `--exclude_datasets` set + if datasets: + raise ValueError("--exclude_datasets can't be used with `datasets`") + datasets = set(tfds.list_builders(with_community_datasets=False)) - set( + self.automation.exclude_datasets.split(',') + ) + datasets = sorted(datasets) # `set` is not deterministic + else: + datasets = datasets or [''] # Empty string for default + + # Import builder classes + builders_cls_and_kwargs = [ + _get_builder_cls_and_kwargs( + dataset, has_imports=bool(self.generation.imports) + ) + for dataset in datasets + ] + + # Parallelize datasets generation. + builders = itertools.chain(*( + _make_builders(self, builder_cls, builder_kwargs) + for (builder_cls, builder_kwargs) in builders_cls_and_kwargs + )) + process_builder_fn = functools.partial( + _download if self.generation.download_only else _download_and_prepare, + self, + ) -def _build_datasets(args: argparse.Namespace) -> None: - """Build the given datasets.""" - # Eventually register additional datasets imports - if args.imports: - list(importlib.import_module(m) for m in args.imports.split(',')) + if self.generation.num_processes == 1: + for builder in builders: + process_builder_fn(builder) + else: + with multiprocessing.Pool(self.generation.num_processes) as pool: + pool.map(process_builder_fn, builders) - # Select datasets to generate - datasets = (args.datasets or []) + (args.datasets_keyword or []) - if args.exclude_datasets: # Generate all datasets if `--exclude_datasets` set - if datasets: - raise ValueError("--exclude_datasets can't be used with `datasets`") - datasets = set(tfds.list_builders(with_community_datasets=False)) - set( - args.exclude_datasets.split(',') - ) - datasets = sorted(datasets) # `set` is not deterministic - else: - datasets = datasets or [''] # Empty string for default - - # Import builder classes - builders_cls_and_kwargs = [ - _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports)) - for dataset in datasets - ] - - # Parallelize datasets generation. - builders = itertools.chain(*( - _make_builders(args, builder_cls, builder_kwargs) - for (builder_cls, builder_kwargs) in builders_cls_and_kwargs - )) - process_builder_fn = functools.partial( - _download if args.download_only else _download_and_prepare, args - ) - if args.num_processes == 1: - for builder in builders: - process_builder_fn(builder) - else: - with multiprocessing.Pool(args.num_processes) as pool: - pool.map(process_builder_fn, builders) +def register_subparser(parsers: argparse._SubParsersAction) -> None: # pylint: disable=protected-access + """Add subparser for `build` command.""" + parser = parsers.add_parser( + 'build', help='Commands for downloading and preparing datasets.' 
+ ) + parser = typing.cast(simple_parsing.ArgumentParser, parser) + parser.add_arguments(Args, dest='args') + parser.set_defaults(subparser_fn=lambda args: args.args.execute()) def _make_builders( - args: argparse.Namespace, + args: Args, builder_cls: Type[tfds.core.DatasetBuilder], builder_kwargs: dict[str, Any], ) -> Iterator[tfds.core.DatasetBuilder]: @@ -146,7 +146,7 @@ def _make_builders( Initialized dataset builders. """ # Eventually overwrite version - if args.experimental_latest_version: + if args.automation.experimental_latest_version: if 'version' in builder_kwargs: raise ValueError( "Can't have both `--experimental_latest` and version set (`:1.0.0`)" @@ -157,19 +157,19 @@ def _make_builders( builder_kwargs['config'] = _get_config_name( builder_cls=builder_cls, config_kwarg=builder_kwargs.get('config'), - config_name=args.config, - config_idx=args.config_idx, + config_name=args.generation.config, + config_idx=args.generation.config_idx, ) - if args.file_format: - builder_kwargs['file_format'] = args.file_format + if args.generation.file_format: + builder_kwargs['file_format'] = args.generation.file_format make_builder = functools.partial( _make_builder, builder_cls, - overwrite=args.overwrite, - fail_if_exists=args.fail_if_exists, - data_dir=args.data_dir, + overwrite=args.debug.overwrite, + fail_if_exists=args.debug.fail_if_exists, + data_dir=args.paths.data_dir, **builder_kwargs, ) @@ -203,7 +203,7 @@ def _get_builder_cls_and_kwargs( if not has_imports: path = _search_script_path(ds_to_build) if path is not None: - logging.info(f'Loading dataset {ds_to_build} from path: {path}') + logging.info('Loading dataset %s from path: %s', ds_to_build, path) # Dynamically load user dataset script # When possible, load from the parent's parent, so module is named # "foo.foo_dataset_builder". 
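# A minimal usage sketch (not part of this diff) of the simple_parsing-based
# `build` command defined above. The root parser name `root_parser` is
# illustrative; argparse subparsers inherit their parent's class, so the
# `typing.cast` in `register_subparser` holds whenever the CLI entry point
# builds a `simple_parsing.ArgumentParser`.
import simple_parsing

from tensorflow_datasets.scripts.cli import build

root_parser = simple_parsing.ArgumentParser()
subparsers = root_parser.add_subparsers()
build.register_subparser(subparsers)

# `tfds build mnist --overwrite` then roughly corresponds to:
namespace = root_parser.parse_args(['build', 'mnist', '--overwrite'])
cmd_args: build.Args = namespace.args  # Filled in by `parser.add_arguments`.
assert cmd_args.positional_datasets == ['mnist']
assert cmd_args.debug.overwrite
# main.py dispatches through the default registered above:
# namespace.subparser_fn(namespace)  # -> Args.execute()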
@@ -228,7 +228,9 @@ def _get_builder_cls_and_kwargs( name, builder_kwargs = tfds.core.naming.parse_builder_name_kwargs(ds_to_build) builder_cls = tfds.builder_cls(str(name)) logging.info( - f'Loading dataset {ds_to_build} from imports: {builder_cls.__module__}' + 'Loading dataset %s from imports: %s', + ds_to_build, + builder_cls.__module__, ) return builder_cls, builder_kwargs @@ -308,7 +310,7 @@ def _make_builder( def _download( - args: argparse.Namespace, + args: Args, builder: tfds.core.DatasetBuilder, ) -> None: """Downloads all files of the given builder.""" @@ -330,7 +332,7 @@ def _download( if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None: max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS - download_dir = args.download_dir or os.path.join( + download_dir = args.paths.download_dir or os.path.join( builder._data_dir_root, 'downloads' # pylint: disable=protected-access ) dl_manager = tfds.download.DownloadManager( @@ -352,39 +354,39 @@ def _download( def _download_and_prepare( - args: argparse.Namespace, + args: Args, builder: tfds.core.DatasetBuilder, ) -> None: """Generate a single builder.""" cli_utils.download_and_prepare( builder=builder, download_config=_make_download_config(args, dataset_name=builder.name), - download_dir=args.download_dir, - publish_dir=args.publish_dir, - skip_if_published=args.skip_if_published, - overwrite=args.overwrite, - beam_pipeline_options=args.beam_pipeline_options, - nondeterministic_order=args.nondeterministic_order, + download_dir=args.paths.download_dir, + publish_dir=args.publishing.publish_dir, + skip_if_published=args.publishing.skip_if_published, + overwrite=args.debug.overwrite, + beam_pipeline_options=args.generation.beam_pipeline_options, + nondeterministic_order=args.generation.nondeterministic_order, ) def _make_download_config( - args: argparse.Namespace, + args: Args, dataset_name: str, ) -> tfds.download.DownloadConfig: """Generate the download and prepare configuration.""" # Load the download config - manual_dir = args.manual_dir - if args.add_name_to_manual_dir: + manual_dir = args.paths.manual_dir + if args.paths.add_name_to_manual_dir: manual_dir = manual_dir / dataset_name kwargs = {} - if args.max_shard_size_mb: - kwargs['max_shard_size'] = args.max_shard_size_mb << 20 - if args.num_shards: - kwargs['num_shards'] = args.num_shards - if args.download_config: - kwargs.update(json.loads(args.download_config)) + if args.generation.max_shard_size_mb: + kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20 + if args.generation.num_shards: + kwargs['num_shards'] = args.generation.num_shards + if args.generation.download_config: + kwargs.update(json.loads(args.generation.download_config)) if 'download_mode' in kwargs: kwargs['download_mode'] = tfds.download.GenerateMode( @@ -392,15 +394,15 @@ def _make_download_config( ) else: kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS - if args.update_metadata_only: + if args.generation.update_metadata_only: kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO return tfds.download.DownloadConfig( - extract_dir=args.extract_dir, + extract_dir=args.paths.extract_dir, manual_dir=manual_dir, - max_examples_per_split=args.max_examples_per_split, - register_checksums=args.register_checksums, - force_checksums_validation=args.force_checksums_validation, + max_examples_per_split=args.debug.max_examples_per_split, + register_checksums=args.generation.register_checksums, + 
force_checksums_validation=args.generation.force_checksums_validation, **kwargs, ) @@ -445,11 +447,10 @@ def _get_config_name( else: return config_name elif config_idx is not None: # `--config_idx 123` - if config_idx > len(builder_cls.BUILDER_CONFIGS): + if config_idx >= len(builder_cls.BUILDER_CONFIGS): raise ValueError( - f'--config_idx {config_idx} greater than number ' - f'of configs {len(builder_cls.BUILDER_CONFIGS)} for ' - f'{builder_cls.name}.' + f'--config_idx {config_idx} greater than number of configs ' + f'{len(builder_cls.BUILDER_CONFIGS)} for {builder_cls.name}.' ) else: # Use `config.name` to avoid diff --git a/tensorflow_datasets/scripts/cli/build_test.py b/tensorflow_datasets/scripts/cli/build_test.py index 97044d11f4a..cbbc68bd686 100644 --- a/tensorflow_datasets/scripts/cli/build_test.py +++ b/tensorflow_datasets/scripts/cli/build_test.py @@ -311,7 +311,8 @@ def test_download_only(build): ) def test_make_download_config(args: str, download_config_kwargs): args = main._parse_flags(f'tfds build x {args}'.split()) - actual = build_lib._make_download_config(args, dataset_name='x') + cmd_args: build_lib.Args = args.args + actual = build_lib._make_download_config(cmd_args, dataset_name='x') # Ignore the beam runner actual = actual.replace(beam_runner=None) expected = tfds.download.DownloadConfig(**download_config_kwargs) diff --git a/tensorflow_datasets/scripts/cli/cli_utils.py b/tensorflow_datasets/scripts/cli/cli_utils.py index 9b57afbb1f1..614a3eb7399 100644 --- a/tensorflow_datasets/scripts/cli/cli_utils.py +++ b/tensorflow_datasets/scripts/cli/cli_utils.py @@ -127,232 +127,137 @@ def __post_init__(self): self.ds_import = ds_import -def add_debug_argument_group(parser: argparse.ArgumentParser): - """Adds debug argument group to the parser.""" - debug_group = parser.add_argument_group( - 'Debug & tests', - description=( - '--pdb Enter post-mortem debugging mode if an exception is raised.' - ), - ) - debug_group.add_argument( - '--overwrite', - action='store_true', - help='Delete pre-existing dataset if it exists.', - ) - debug_group.add_argument( - '--fail_if_exists', - action='store_true', - default=False, - help='Fails the program if there is a pre-existing dataset.', - ) - debug_group.add_argument( - '--max_examples_per_split', - type=int, - nargs='?', - const=1, - help=( - 'When set, only generate the first X examples (default to 1), rather' - ' than the full dataset.If set to 0, only execute the' - ' `_split_generators` (which download the original data), but skip' - ' `_generator_examples`' - ), - ) +@dataclasses.dataclass(frozen=True, kw_only=True) +class DebugOptions: + """Debug & tests options. + Attributes: + overwrite: If True, delete pre-existing dataset if it exists. + fail_if_exists: If True, fails the program if there is a pre-existing + dataset. + max_examples_per_split: When set, only generate the first X examples + (default to 1), rather than the full dataset. If set to 0, only execute + the `_split_generators` (which download the original data), but skip + `_generator_examples`. + """ -def add_path_argument_group(parser: argparse.ArgumentParser): - """Adds path argument group to the parser.""" - path_group = parser.add_argument_group('Paths') - path_group.add_argument( - '--data_dir', - type=epath.Path, - default=epath.Path(constants.DATA_DIR), - help=( - 'Where to place datasets. Default to ' - '`~/tensorflow_datasets/` or `TFDS_DATA_DIR` environement variable.' 
- ), - ) - path_group.add_argument( - '--download_dir', - type=epath.Path, - help='Where to place downloads. Default to `/downloads/`.', - ) - path_group.add_argument( - '--extract_dir', - type=epath.Path, - help='Where to extract files. Default to `/extracted/`.', - ) - path_group.add_argument( - '--manual_dir', - type=epath.Path, - help=( - 'Where to manually download data (required for some datasets). ' - 'Default to `/manual/`.' - ), - ) - path_group.add_argument( - '--add_name_to_manual_dir', - action='store_true', - help=( - 'If true, append the dataset name to the `manual_dir` (e.g. ' - '`/manual//`. Useful to avoid collisions ' - 'if many datasets are generated.' - ), + overwrite: bool = simple_parsing.flag(default=False) + fail_if_exists: bool = simple_parsing.flag(default=False) + max_examples_per_split: int | None = simple_parsing.field( + default=None, nargs='?', const=1 ) -def add_generation_argument_group(parser: argparse.ArgumentParser): - """Adds generation argument group to the parser.""" - generation_group = parser.add_argument_group('Generation') - generation_group.add_argument( - '--download_only', - action='store_true', - help=( - 'If True, download all files but do not prepare the dataset. Uses the' - ' checksum.tsv to find out what to download. Therefore, this does not' - ' work in combination with --register_checksums.' - ), - ) - generation_group.add_argument( - '--config', - '-c', - type=str, - help=( - 'Config name to build. Build all configs if not set. Can also be a' - ' json of the kwargs forwarded to the config `__init__` (for custom' - ' configs).' - ), - ) - # We are forced to have 2 flags to avoid ambiguity when config name is - # a number (e.g. `voc/2017`) - generation_group.add_argument( - '--config_idx', - type=int, - help=( - 'Config id to build (`builder_cls.BUILDER_CONFIGS[config_idx]`). ' - 'Mutually exclusive with `--config`.' - ), - ) - generation_group.add_argument( - '--update_metadata_only', - action='store_true', - default=False, - help=( - 'If True, existing dataset_info.json is updated with metadata defined' - ' in Builder class(es). Datasets must already have been prepared.' - ), - ) - generation_group.add_argument( - '--download_config', - type=str, - help=( - 'A json of the kwargs forwarded to the config `__init__` (for custom' - ' DownloadConfigs).' - ), - ) - generation_group.add_argument( - '--imports', - '-i', - type=str, - help='Comma separated list of module to import to register datasets.', - ) - generation_group.add_argument( - '--register_checksums', - action='store_true', - help='If True, store size and checksum of downloaded files.', - ) - generation_group.add_argument( - '--force_checksums_validation', - action='store_true', - help='If True, raise an error if the checksums are not found.', - ) - # For compatibility with absl.flags (which generates --foo and --nofoo). - generation_group.add_argument( - '--noforce_checksums_validation', - dest='force_checksums_validation', - action='store_false', - help='If specified, bypass the checks on the checksums.', - ) - generation_group.add_argument( - '--beam_pipeline_options', - type=str, - # nargs='+', - help=( - 'A (comma-separated) list of flags to pass to `PipelineOptions` when' - ' preparing with Apache Beam. (see:' - ' https://www.tensorflow.org/datasets/beam_datasets). 
Example:'
-          ' `--beam_pipeline_options=job_name=my-job,project=my-project`'
-      ),
-  )
-  format_values = [f.value for f in file_adapters.FileFormat]
-  generation_group.add_argument(
-      '--file_format',
-      type=str,
-      help=(
-          'File format to which generate the tf-examples. '
-          f'Available values: {format_values} (see `tfds.core.FileFormat`).'
-      ),
-  )
-  generation_group.add_argument(
-      '--max_shard_size_mb', type=int, help='The max shard size in megabytes.'
-  )
-  generation_group.add_argument(
-      '--num_shards', type=int, help='The number of shards to write to.'
-  )
-  generation_group.add_argument(
-      '--num-processes',
-      type=int,
-      default=1,
-      help='Number of parallel build processes.',
-  )
-  generation_group.add_argument(
-      '--nondeterministic_order',
-      action='store_true',
-      default=False,
-      help=(
-          'If True, it will not assure deterministic ordering when writing'
-          ' examples to disk. This might result in quicker dataset preparation.'
-      ),
-  )
-  # For compatibility with absl.flags (which generates --foo and --nofoo).
-  generation_group.add_argument(
-      '--nonondeterministic_order',
-      dest='nondeterministic_order',
-      action='store_false',
-      help=(
-          'If specified, it will assure deterministic ordering when writing'
-          ' examples to disk.'
-      ),
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class PathOptions:
+  """Path options.
+
+  Attributes:
+    data_dir: Where to place datasets. Default to `~/tensorflow_datasets/` or
+      `TFDS_DATA_DIR` environment variable.
+    download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
+    extract_dir: Where to extract files. Default to `<download_dir>/extracted/`.
+    manual_dir: Where to manually download data (required for some datasets).
+      Default to `<download_dir>/manual/`.
+    add_name_to_manual_dir: If true, append the dataset name to the
+      `manual_dir` (e.g. `<download_dir>/manual/<dataset_name>/`). Useful to
+      avoid collisions if many datasets are generated.
+  """
-def add_publish_argument_group(parser: argparse.ArgumentParser):
-  """Adds publish argument group to the parser."""
-  publish_group = parser.add_argument_group(
-      'Publishing',
-      description='Options for publishing successfully created datasets.',
+  data_dir: epath.Path = simple_parsing.field(
+      default=epath.Path(constants.DATA_DIR)
   )
-  publish_group.add_argument(
-      '--publish_dir',
-      type=epath.Path,
+  download_dir: epath.Path | None = None
+  extract_dir: epath.Path | None = None
+  manual_dir: epath.Path | None = None
+  add_name_to_manual_dir: bool = simple_parsing.flag(default=False)
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class GenerationOptions:
+  """Generation options.
+
+  Attributes:
+    download_only: If True, download all files but do not prepare the dataset.
+      Uses the checksum.tsv to find out what to download. Therefore, this does
+      not work in combination with --register_checksums.
+    config: Config name to build. Build all configs if not set. Can also be a
+      json of the kwargs forwarded to the config `__init__` (for custom
+      configs).
+    config_idx: Config id to build (`builder_cls.BUILDER_CONFIGS[config_idx]`).
+      Mutually exclusive with `--config`. We are forced to have 2 flags to
+      avoid ambiguity when `config` is a number (e.g. `voc/2017`).
+    update_metadata_only: If True, existing dataset_info.json is updated with
+      metadata defined in Builder class(es). Datasets must already have been
+      prepared.
+    download_config: A json of the kwargs forwarded to the config `__init__`
+      (for custom DownloadConfigs).
+    imports: Comma separated list of modules to import to register datasets.
+ register_checksums: If True, store size and checksum of downloaded files. + force_checksums_validation: If True, raise an error if the checksums are not + found. Otherwise, bypass the checks on the checksums + beam_pipeline_options: A (comma-separated) list of flags to pass to + `PipelineOptions` when preparing with Apache Beam. (see: + https://www.tensorflow.org/datasets/beam_datasets). Example: + `--beam_pipeline_options=job_name=my-job,project=my-project` + file_format: File format to which generate the tf-examples. + max_shard_size_mb: The max shard size in megabytes. + num_shards: The number of shards to write to. + num_processes: Number of parallel build processes. + nondeterministic_order: If True, it will not assure deterministic ordering + when writing examples to disk. This might result in quicker dataset + preparation. Otherwise, it will assure deterministic ordering when writing + examples to disk + """ + + download_only: bool = simple_parsing.flag(default=False) + config: str | None = simple_parsing.field(default=None, alias='-c') + config_idx: int | None = None + update_metadata_only: bool = simple_parsing.flag(default=False) + download_config: str | None = None + imports: str | None = simple_parsing.field(default=None, alias='-i') + register_checksums: bool = simple_parsing.flag(default=False) + force_checksums_validation: bool = simple_parsing.flag(default=False) + beam_pipeline_options: str | None = None + file_format: str | None = simple_parsing.choice( + *(file_format.value for file_format in file_adapters.FileFormat), default=None, - required=False, - help=( - 'Where to optionally publish the dataset after it has been ' - 'generated successfully. Should be the root data dir under which' - 'datasets are stored. ' - 'If unspecified, dataset will not be published' - ), - ) - publish_group.add_argument( - '--skip_if_published', - action='store_true', - default=False, - help=( - 'If the dataset with the same version and config is already ' - 'published, then it will not be regenerated.' - ), ) + max_shard_size_mb: int | None = None + num_shards: int | None = None + num_processes: int = simple_parsing.field(default=1, alias='num-processes') + nondeterministic_order: bool = simple_parsing.flag(default=False) + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class PublishingOptions: + """Publishing options. + + Attributes: + publish_dir: Where to optionally publish the dataset after it has been + generated successfully. Should be the root data dir under which datasets + are stored. If unspecified, dataset will not be published. + skip_if_published: If the dataset with the same version and config is + already published, then it will not be regenerated. + """ + + publish_dir: epath.Path | None = None + skip_if_published: bool = simple_parsing.flag(default=False) + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class AutomationOptions: + """Automation options. + + Attributes: + exclude_datasets: If set, generate all datasets except the one defined here. + Comma separated list of datasets to exclude. + experimental_latest_version: Build the latest Version(experiments=...) + available rather than default version. 
+ """ + + exclude_datasets: str | None = None + experimental_latest_version: bool = simple_parsing.flag(default=False) def download_and_prepare( diff --git a/tensorflow_datasets/scripts/download_and_prepare.py b/tensorflow_datasets/scripts/download_and_prepare.py index 496b24afdc9..2898a370a06 100644 --- a/tensorflow_datasets/scripts/download_and_prepare.py +++ b/tensorflow_datasets/scripts/download_and_prepare.py @@ -16,12 +16,11 @@ r"""Wrapper around `tfds build`.""" import argparse -from typing import List from absl import app from absl import flags from absl import logging - +from tensorflow_datasets.scripts.cli import build from tensorflow_datasets.scripts.cli import main as main_cli module_import = flags.DEFINE_string('module_import', None, '`--imports` flag.') @@ -33,7 +32,7 @@ -def _parse_flags(argv: List[str]) -> argparse.Namespace: +def _parse_flags(argv: list[str]) -> argparse.Namespace: """Command lines flag parsing.""" return main_cli._parse_flags([argv[0], 'build'] + argv[1:]) # pylint: disable=protected-access @@ -46,12 +45,13 @@ def main(args: argparse.Namespace) -> None: logging.warning( '***`tfds build` should be used instead of `download_and_prepare`.***' ) + cmd_args: build.Args = args.args if module_import.value: - args.imports = module_import.value + cmd_args.generation.imports = module_import.value if dataset.value: - args.datasets = [dataset.value] + cmd_args.datasets = [dataset.value] if builder_config_id.value is not None: - args.config_idx = builder_config_id.value + cmd_args.generation.config_idx = builder_config_id.value main_cli.main(args)