import argparse
from collections.abc import Iterator
+import dataclasses
import functools
import importlib
import itertools
import json
import multiprocessing
import os
+import typing
from typing import Any, Type

from absl import logging
+import simple_parsing
import tensorflow_datasets as tfds
from tensorflow_datasets.scripts.cli import cli_utils

-# pylint: disable=logging-fstring-interpolation
-
-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command.
-
-  New flags should be added to `cli_utils` module.
-
-  Args:
-    parsers: The subparsers object to add the parser to.
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class Args:
+  """CLI arguments for building datasets.
+
+  Attributes:
+    positional_datasets: Name(s) of the dataset(s) to build. Defaults to the
+      current dir. See https://www.tensorflow.org/datasets/cli for accepted
+      values.
+    datasets: Datasets can also be provided as a keyword argument.
+    debug: Debug & test options. Use --pdb to enter post-mortem debugging mode
+      if an exception is raised.
+    paths: Path options.
+    generation: Generation options.
+    publishing: Publishing options.
+    automation: Automation options.
  """
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
+
+  positional_datasets: list[str] = simple_parsing.field(
+      positional=True,
      nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
+      default_factory=list,
+      # Need to explicitly set metavar for command-line help.
+      metavar='datasets',
  )
+  datasets: list[str] = simple_parsing.field(nargs='*', default_factory=list)
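+  # `positional_datasets` and `--datasets` are concatenated in `execute()`.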

-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
-
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
+  debug: cli_utils.DebugOptions = cli_utils.DebugOptions()
+  paths: cli_utils.PathOptions = simple_parsing.field(
+      default_factory=cli_utils.PathOptions
+  )
+  generation: cli_utils.GenerationOptions = simple_parsing.field(
+      default_factory=cli_utils.GenerationOptions
  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  publishing: cli_utils.PublishingOptions = simple_parsing.field(
+      default_factory=cli_utils.PublishingOptions
  )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  automation: cli_utils.AutomationOptions = simple_parsing.field(
+      default_factory=cli_utils.AutomationOptions
  )
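
+  # Note: `simple_parsing` renders each nested options dataclass above as its
+  # own group of flags in `--help` (see `register_subparser` below).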
-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self) -> None:
+    """Build the given datasets."""
+    # Register any additional dataset imports.
+    if self.generation.imports:
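+      # `list(...)` forces eager evaluation of the import generator.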
+      list(
+          importlib.import_module(m) for m in self.generation.imports.split(',')
+      )
+
+    # Select datasets to generate
+    datasets = self.positional_datasets + self.datasets
+    if (
+        self.automation.exclude_datasets
+    ):  # Generate all datasets if `--exclude_datasets` set
+      if datasets:
+        raise ValueError("--exclude_datasets can't be used with `datasets`")
+      datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
+          self.automation.exclude_datasets.split(',')
+      )
+      datasets = sorted(datasets)  # `set` is not deterministic
+    else:
+      datasets = datasets or ['']  # Empty string for default
+
+    # Import builder classes
+    builders_cls_and_kwargs = [
+        _get_builder_cls_and_kwargs(
+            dataset, has_imports=bool(self.generation.imports)
+        )
+        for dataset in datasets
+    ]
+
+    # Parallelize datasets generation.
+    builders = itertools.chain(*(
+        _make_builders(self, builder_cls, builder_kwargs)
+        for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
+    ))
+    process_builder_fn = functools.partial(
+        _download if self.generation.download_only else _download_and_prepare,
+        self,
+    )
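+    # Binding `self` here leaves a callable that takes only a builder, so the
+    # same function works for the serial loop and for `pool.map` below.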

-def _build_datasets(args: argparse.Namespace) -> None:
-  """Build the given datasets."""
-  # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+    if self.generation.num_processes == 1:
+      for builder in builders:
+        process_builder_fn(builder)
+    else:
+      with multiprocessing.Pool(self.generation.num_processes) as pool:
+        pool.map(process_builder_fn, builders)

-  # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
-    if datasets:
-      raise ValueError("--exclude_datasets can't be used with `datasets`")
-    datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
-    )
-    datasets = sorted(datasets)  # `set` is not deterministic
-  else:
-    datasets = datasets or ['']  # Empty string for default
-
-  # Import builder classes
-  builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
-      for dataset in datasets
-  ]
-
-  # Parallelize datasets generation.
-  builders = itertools.chain(*(
-      _make_builders(args, builder_cls, builder_kwargs)
-      for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
-  ))
-  process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
-  )
-
-  if args.num_processes == 1:
-    for builder in builders:
-      process_builder_fn(builder)
-  else:
-    with multiprocessing.Pool(args.num_processes) as pool:
-      pool.map(process_builder_fn, builders)

+def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
+  """Add subparser for `build` command."""
+  parser = parsers.add_parser(
+      'build', help='Commands for downloading and preparing datasets.'
+  )
+  parser = typing.cast(simple_parsing.ArgumentParser, parser)
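+  # `add_arguments` generates the CLI flags from the `Args` dataclass fields.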
+  parser.add_arguments(Args, dest='args')
+  parser.set_defaults(subparser_fn=lambda args: args.args.execute())
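+  # Example invocations (a sketch; assumes the `tfds` console entry point):
+  #   tfds build mnist
+  #   tfds build --datasets mnist imagenet2012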


def _make_builders(
-    args: argparse.Namespace,
+    args: Args,
    builder_cls: Type[tfds.core.DatasetBuilder],
    builder_kwargs: dict[str, Any],
) -> Iterator[tfds.core.DatasetBuilder]:
@@ -146,7 +146,7 @@ def _make_builders(
    Initialized dataset builders.
  """
  # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
    if 'version' in builder_kwargs:
      raise ValueError(
          "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -157,19 +157,19 @@ def _make_builders(
  builder_kwargs['config'] = _get_config_name(
      builder_cls=builder_cls,
      config_kwarg=builder_kwargs.get('config'),
-      config_name=args.config,
-      config_idx=args.config_idx,
+      config_name=args.generation.config,
+      config_idx=args.generation.config_idx,
  )

-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format

  make_builder = functools.partial(
      _make_builder,
      builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.paths.data_dir,
      **builder_kwargs,
  )
@@ -203,7 +203,7 @@ def _get_builder_cls_and_kwargs(
  if not has_imports:
    path = _search_script_path(ds_to_build)
    if path is not None:
-      logging.info(f'Loading dataset {ds_to_build} from path: {path}')
+      logging.info('Loading dataset %s from path: %s', ds_to_build, path)
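+      # Passing `%s` args (rather than an f-string) lets `logging` defer the
+      # string interpolation until the record is actually emitted.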
      # Dynamically load user dataset script
      # When possible, load from the parent's parent, so module is named
      # "foo.foo_dataset_builder".
@@ -228,7 +228,9 @@ def _get_builder_cls_and_kwargs(
  name, builder_kwargs = tfds.core.naming.parse_builder_name_kwargs(ds_to_build)
  builder_cls = tfds.builder_cls(str(name))
  logging.info(
-      f'Loading dataset {ds_to_build} from imports: {builder_cls.__module__}'
+      'Loading dataset %s from imports: %s',
+      ds_to_build,
+      builder_cls.__module__,
  )
  return builder_cls, builder_kwargs
@@ -308,7 +310,7 @@ def _make_builder(


def _download(
-    args: argparse.Namespace,
+    args: Args,
    builder: tfds.core.DatasetBuilder,
) -> None:
  """Downloads all files of the given builder."""
@@ -330,7 +332,7 @@ def _download(
  if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
    max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS

-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.paths.download_dir or os.path.join(
      builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
  )
  dl_manager = tfds.download.DownloadManager(
@@ -352,55 +354,55 @@ def _download(


def _download_and_prepare(
-    args: argparse.Namespace,
+    args: Args,
    builder: tfds.core.DatasetBuilder,
) -> None:
  """Generate a single builder."""
  cli_utils.download_and_prepare(
      builder=builder,
      download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
-      beam_pipeline_options=args.beam_pipeline_options,
-      nondeterministic_order=args.nondeterministic_order,
+      download_dir=args.paths.download_dir,
+      publish_dir=args.publishing.publish_dir,
+      skip_if_published=args.publishing.skip_if_published,
+      overwrite=args.debug.overwrite,
+      beam_pipeline_options=args.generation.beam_pipeline_options,
+      nondeterministic_order=args.generation.nondeterministic_order,
  )


def _make_download_config(
-    args: argparse.Namespace,
+    args: Args,
    dataset_name: str,
) -> tfds.download.DownloadConfig:
  """Generate the download and prepare configuration."""
  # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.paths.manual_dir
+  if args.paths.add_name_to_manual_dir:
    manual_dir = manual_dir / dataset_name

  kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.num_shards:
-    kwargs['num_shards'] = args.num_shards
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
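+    # `<< 20` multiplies by 2**20, converting MiB to bytes.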
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.num_shards:
+    kwargs['num_shards'] = args.generation.num_shards
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))

  if 'download_mode' in kwargs:
    kwargs['download_mode'] = tfds.download.GenerateMode(
        kwargs['download_mode']
    )
  else:
    kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
    kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO

  return tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.paths.extract_dir,
      manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
      **kwargs,
  )
@@ -445,11 +447,10 @@ def _get_config_name(
    else:
      return config_name
  elif config_idx is not None:  # `--config_idx 123`
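+    # `config_idx` is zero-based, so indices >= the number of configs are
+    # out of range.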
-    if config_idx > len(builder_cls.BUILDER_CONFIGS):
+    if config_idx >= len(builder_cls.BUILDER_CONFIGS):
      raise ValueError(
-          f'--config_idx {config_idx} greater than number '
-          f'of configs {len(builder_cls.BUILDER_CONFIGS)} for '
-          f'{builder_cls.name}.'
+          f'--config_idx {config_idx} is greater than the number of configs '
+          f'{len(builder_cls.BUILDER_CONFIGS)} for {builder_cls.name}.'
      )
    else:
      # Use `config.name` to avoid