
Commit af58e19

fineguy authored and The TensorFlow Datasets Authors committed
Use simple_parsing for build cli command.
PiperOrigin-RevId: 792093446
1 parent 77ffb39 commit af58e19
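
The change below replaces the hand-rolled argparse wiring for `tfds build` with a frozen dataclass that simple_parsing turns into a parser. As a quick orientation, here is a minimal sketch of that pattern; DemoArgs and its fields are illustrative stand-ins, not part of this commit:

import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True)
class DemoArgs:
  """Toy argument container."""

  # A positional list, mirroring `positional_datasets` in the real change.
  names: list[str] = simple_parsing.field(
      positional=True, nargs='*', default_factory=list
  )
  # Plain fields become ordinary flags, e.g. `--num_processes 2`.
  num_processes: int = 1


parser = simple_parsing.ArgumentParser()
parser.add_arguments(DemoArgs, dest='args')
namespace = parser.parse_args(['mnist', 'cifar10', '--num_processes', '2'])
print(namespace.args)  # DemoArgs(names=['mnist', 'cifar10'], num_processes=2)

Everything in the diff follows from this: the dataclass replaces argparse.Namespace throughout the module.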

File tree

4 files changed: +254 -347 lines

tensorflow_datasets/scripts/cli/build.py

Lines changed: 125 additions & 124 deletions
@@ -17,121 +17,121 @@
 
 import argparse
 from collections.abc import Iterator
+import dataclasses
 import functools
 import importlib
 import itertools
 import json
 import multiprocessing
 import os
+import typing
 from typing import Any, Type
 
 from absl import logging
+import simple_parsing
 import tensorflow_datasets as tfds
 from tensorflow_datasets.scripts.cli import cli_utils
 
-# pylint: disable=logging-fstring-interpolation
 
-
-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command.
-
-  New flags should be added to `cli_utils` module.
-
-  Args:
-    parsers: The subparsers object to add the parser to.
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class Args:
+  """CLI arguments for building datasets.
+
+  Attributes:
+    positional_datasets: Name(s) of the dataset(s) to build. Default to current
+      dir. See https://www.tensorflow.org/datasets/cli for accepted values.
+    datasets: Datasets can also be provided as keyword argument.
+    debug: Debug & tests options. Use --pdb to enter post-mortem debugging mode
+      if an exception is raised.
+    paths: Path options.
+    generation: Generation options.
+    publishing: Publishing options.
+    automation: Automation options.
   """
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
+
+  positional_datasets: list[str] = simple_parsing.field(
+      positional=True,
       nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
+      default_factory=list,
+      # Need to explicitly set metavar for command-line help.
+      metavar='datasets',
   )
+  datasets: list[str] = simple_parsing.field(nargs='*', default_factory=list)
 
-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
-
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
+  debug: cli_utils.DebugOptions = cli_utils.DebugOptions()
+  paths: cli_utils.PathOptions = simple_parsing.field(
+      default_factory=cli_utils.PathOptions
+  )
+  generation: cli_utils.GenerationOptions = simple_parsing.field(
+      default_factory=cli_utils.GenerationOptions
   )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  publishing: cli_utils.PublishingOptions = simple_parsing.field(
+      default_factory=cli_utils.PublishingOptions
   )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  automation: cli_utils.AutomationOptions = simple_parsing.field(
+      default_factory=cli_utils.AutomationOptions
   )
 
-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self) -> None:
+    """Build the given datasets."""
+    # Eventually register additional datasets imports
+    if self.generation.imports:
+      list(
+          importlib.import_module(m) for m in self.generation.imports.split(',')
+      )
 
+    # Select datasets to generate
+    datasets = self.positional_datasets + self.datasets
+    if (
+        self.automation.exclude_datasets
+    ):  # Generate all datasets if `--exclude_datasets` set
+      if datasets:
+        raise ValueError("--exclude_datasets can't be used with `datasets`")
+      datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
+          self.automation.exclude_datasets.split(',')
+      )
+      datasets = sorted(datasets)  # `set` is not deterministic
+    else:
+      datasets = datasets or ['']  # Empty string for default
+
+    # Import builder classes
+    builders_cls_and_kwargs = [
+        _get_builder_cls_and_kwargs(
+            dataset, has_imports=bool(self.generation.imports)
+        )
+        for dataset in datasets
+    ]
+
+    # Parallelize datasets generation.
+    builders = itertools.chain(*(
+        _make_builders(self, builder_cls, builder_kwargs)
+        for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
+    ))
+    process_builder_fn = functools.partial(
+        _download if self.generation.download_only else _download_and_prepare,
+        self,
+    )
 
-def _build_datasets(args: argparse.Namespace) -> None:
-  """Build the given datasets."""
-  # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+    if self.generation.num_processes == 1:
+      for builder in builders:
+        process_builder_fn(builder)
+    else:
+      with multiprocessing.Pool(self.generation.num_processes) as pool:
+        pool.map(process_builder_fn, builders)
 
-  # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
-    if datasets:
-      raise ValueError("--exclude_datasets can't be used with `datasets`")
-    datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
-    )
-    datasets = sorted(datasets)  # `set` is not deterministic
-  else:
-    datasets = datasets or ['']  # Empty string for default
-
-  # Import builder classes
-  builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
-      for dataset in datasets
-  ]
-
-  # Parallelize datasets generation.
-  builders = itertools.chain(*(
-      _make_builders(args, builder_cls, builder_kwargs)
-      for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
-  ))
-  process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
-  )
 
-  if args.num_processes == 1:
-    for builder in builders:
-      process_builder_fn(builder)
-  else:
-    with multiprocessing.Pool(args.num_processes) as pool:
-      pool.map(process_builder_fn, builders)
+def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
+  """Add subparser for `build` command."""
+  parser = parsers.add_parser(
+      'build', help='Commands for downloading and preparing datasets.'
+  )
+  parser = typing.cast(simple_parsing.ArgumentParser, parser)
+  parser.add_arguments(Args, dest='args')
+  parser.set_defaults(subparser_fn=lambda args: args.args.execute())
 
 
 def _make_builders(
-    args: argparse.Namespace,
+    args: Args,
     builder_cls: Type[tfds.core.DatasetBuilder],
     builder_kwargs: dict[str, Any],
 ) -> Iterator[tfds.core.DatasetBuilder]:
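
The former argparse argument groups (Debug, Paths, Generation, Publishing, Automation) now live as nested dataclasses on `Args`. A sketch of how nesting behaves, assuming simple_parsing's default conflict resolution, under which fields of a nested dataclass surface as flat flags when names don't collide (DemoDebug and DemoArgs are illustrative, not the real `cli_utils` options):

import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True)
class DemoDebug:
  overwrite: bool = False  # Surfaces as a flat `--overwrite` flag.


@dataclasses.dataclass(frozen=True)
class DemoArgs:
  debug: DemoDebug = dataclasses.field(default_factory=DemoDebug)


parser = simple_parsing.ArgumentParser()
parser.add_arguments(DemoArgs, dest='args')
parsed = parser.parse_args(['--overwrite']).args
assert parsed.debug.overwrite  # The flag is stored on the nested group.

This is why the accesses in the hunks below change from `args.overwrite` to `args.debug.overwrite`, `args.data_dir` to `args.paths.data_dir`, and so on, while the command line stays the same.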
@@ -146,7 +146,7 @@ def _make_builders(
     Initialized dataset builders.
   """
   # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
     if 'version' in builder_kwargs:
       raise ValueError(
           "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -157,19 +157,19 @@ def _make_builders(
   builder_kwargs['config'] = _get_config_name(
       builder_cls=builder_cls,
       config_kwarg=builder_kwargs.get('config'),
-      config_name=args.config,
-      config_idx=args.config_idx,
+      config_name=args.generation.config,
+      config_idx=args.generation.config_idx,
   )
 
-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format
 
   make_builder = functools.partial(
       _make_builder,
       builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.paths.data_dir,
       **builder_kwargs,
   )
 
@@ -203,7 +203,7 @@ def _get_builder_cls_and_kwargs(
   if not has_imports:
     path = _search_script_path(ds_to_build)
     if path is not None:
-      logging.info(f'Loading dataset {ds_to_build} from path: {path}')
+      logging.info('Loading dataset %s from path: %s', ds_to_build, path)
       # Dynamically load user dataset script
       # When possible, load from the parent's parent, so module is named
       # "foo.foo_dataset_builder".
@@ -228,7 +228,9 @@ def _get_builder_cls_and_kwargs(
   name, builder_kwargs = tfds.core.naming.parse_builder_name_kwargs(ds_to_build)
   builder_cls = tfds.builder_cls(str(name))
   logging.info(
-      f'Loading dataset {ds_to_build} from imports: {builder_cls.__module__}'
+      'Loading dataset %s from imports: %s',
+      ds_to_build,
+      builder_cls.__module__,
   )
   return builder_cls, builder_kwargs
 
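The two logging changes above swap f-strings for %-style arguments. Like the stdlib logging module, absl formats lazily: the message is only interpolated when the record is actually emitted, which is also why the module-level `# pylint: disable=logging-fstring-interpolation` could be dropped. A small illustration:

from absl import logging

name, module = 'mnist', 'tfds.image.mnist'
# Lazy: the string is built only if the INFO record is emitted.
logging.info('Loading dataset %s from imports: %s', name, module)
# Eager: the f-string is always built, even if INFO logging is off.
logging.info(f'Loading dataset {name} from imports: {module}')
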
@@ -308,7 +310,7 @@ def _make_builder(
 
 
 def _download(
-    args: argparse.Namespace,
+    args: Args,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Downloads all files of the given builder."""
@@ -330,7 +332,7 @@ def _download(
   if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
     max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS
 
-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.paths.download_dir or os.path.join(
       builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
   )
   dl_manager = tfds.download.DownloadManager(
@@ -352,55 +354,55 @@
 
 
 def _download_and_prepare(
-    args: argparse.Namespace,
+    args: Args,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Generate a single builder."""
   cli_utils.download_and_prepare(
       builder=builder,
       download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
-      beam_pipeline_options=args.beam_pipeline_options,
-      nondeterministic_order=args.nondeterministic_order,
+      download_dir=args.paths.download_dir,
+      publish_dir=args.publishing.publish_dir,
+      skip_if_published=args.publishing.skip_if_published,
+      overwrite=args.debug.overwrite,
+      beam_pipeline_options=args.generation.beam_pipeline_options,
+      nondeterministic_order=args.generation.nondeterministic_order,
   )
 
 
 def _make_download_config(
-    args: argparse.Namespace,
+    args: Args,
     dataset_name: str,
 ) -> tfds.download.DownloadConfig:
   """Generate the download and prepare configuration."""
   # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.paths.manual_dir
+  if args.paths.add_name_to_manual_dir:
     manual_dir = manual_dir / dataset_name
 
   kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.num_shards:
-    kwargs['num_shards'] = args.num_shards
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.num_shards:
+    kwargs['num_shards'] = args.generation.num_shards
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))
 
   if 'download_mode' in kwargs:
     kwargs['download_mode'] = tfds.download.GenerateMode(
         kwargs['download_mode']
    )
   else:
     kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
     kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO
 
   return tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.paths.extract_dir,
       manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
       **kwargs,
   )
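For reference, the `<< 20` in `_make_download_config` converts the `--max_shard_size_mb` flag from mebibytes to bytes (1 MiB = 2**20 bytes):

max_shard_size_mb = 64
assert max_shard_size_mb << 20 == 64 * 2**20 == 67_108_864
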
@@ -445,11 +447,10 @@ def _get_config_name(
     else:
       return config_name
   elif config_idx is not None:  # `--config_idx 123`
-    if config_idx > len(builder_cls.BUILDER_CONFIGS):
+    if config_idx >= len(builder_cls.BUILDER_CONFIGS):
       raise ValueError(
-          f'--config_idx {config_idx} greater than number '
-          f'of configs {len(builder_cls.BUILDER_CONFIGS)} for '
-          f'{builder_cls.name}.'
+          f'--config_idx {config_idx} greater than number of configs '
+          f'{len(builder_cls.BUILDER_CONFIGS)} for {builder_cls.name}.'
       )
   else:
     # Use `config.name` to avoid
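
The comparison fix above closes an off-by-one: with N configs the valid indices are 0..N-1, so `--config_idx N` must be rejected as well. The old `>` check let that value through, presumably failing later with an IndexError instead of this ValueError when the index is used:

configs = ['a', 'b', 'c']  # N == 3
config_idx = 3
assert not (config_idx > len(configs))  # Old check: 3 > 3 is False, slips by.
assert config_idx >= len(configs)       # New check: rejected up front.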

tensorflow_datasets/scripts/cli/build_test.py

Lines changed: 2 additions & 1 deletion
@@ -311,7 +311,8 @@ def test_download_only(build):
 )
 def test_make_download_config(args: str, download_config_kwargs):
   args = main._parse_flags(f'tfds build x {args}'.split())
-  actual = build_lib._make_download_config(args, dataset_name='x')
+  cmd_args: build_lib.Args = args.args
+  actual = build_lib._make_download_config(cmd_args, dataset_name='x')
   # Ignore the beam runner
   actual = actual.replace(beam_runner=None)
   expected = tfds.download.DownloadConfig(**download_config_kwargs)
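
The test now unwraps the dataclass because `add_arguments(Args, dest='args')` stores the parsed instance under that attribute on the argparse namespace. A self-contained sketch of the same mechanics (DemoOpts is illustrative):

import dataclasses

import simple_parsing


@dataclasses.dataclass
class DemoOpts:
  x: int = 0


parser = simple_parsing.ArgumentParser()
parser.add_arguments(DemoOpts, dest='args')
namespace = parser.parse_args(['--x', '1'])
assert namespace.args == DemoOpts(x=1)  # Dataclass sits at `namespace.args`.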
