21
21
--jsonld=/tmp/croissant.json \
22
22
--data_dir=/tmp/foo \
23
23
--file_format=array_record \
24
- --record_sets=record1 --record_sets= record2 \
24
+ --record_sets=record1, record2 \
25
25
--mapping='{"document.csv": "~/Downloads/document.csv"}"'
26
26
```
27
27
"""
28
28
29
29
import argparse
30
- from collections . abc import Sequence
30
+ import dataclasses
31
31
import json
32
+ import typing
32
33
33
34
from etils import epath
35
+ import simple_parsing
36
+ from tensorflow_datasets .core import file_adapters
34
37
from tensorflow_datasets .core .dataset_builders import croissant_builder
35
38
from tensorflow_datasets .scripts .cli import cli_utils
36
39
37
40
38
- def add_parser_arguments (parser : argparse .ArgumentParser ):
39
- """Add arguments for `build_croissant` subparser."""
40
- parser .add_argument (
41
- '--jsonld' ,
42
- type = str ,
43
- help = 'The Croissant config file for the given dataset.' ,
44
- required = True ,
45
- )
46
- parser .add_argument (
47
- '--record_sets' ,
48
- nargs = '*' ,
49
- help = (
50
- 'The names of the record sets to generate. Each record set will'
51
- ' correspond to a separate config. If not specified, it will use all'
52
- ' the record sets'
53
- ),
54
- )
55
- parser .add_argument (
56
- '--mapping' ,
57
- type = str ,
58
- help = (
59
- 'Mapping filename->filepath as a Python dict[str, str] to handle'
60
- ' manual downloads. If `document.csv` is the FileObject and you'
61
- ' downloaded it to `~/Downloads/document.csv`, you can'
62
- ' specify`--mapping=\' {"document.csv": "~/Downloads/document.csv"}\' '
63
- ),
64
- )
65
-
66
- cli_utils .add_debug_argument_group (parser )
67
- cli_utils .add_path_argument_group (parser )
68
- cli_utils .add_generation_argument_group (parser )
69
- cli_utils .add_publish_argument_group (parser )
41
+ @dataclasses .dataclass
42
+ class CmdArgs :
43
+ """CLI arguments for preparing a Croissant dataset.
70
44
71
-
72
- def register_subparser (parsers : argparse ._SubParsersAction ):
73
- """Add subparser for `convert_format` command."""
74
- parser = parsers .add_parser (
75
- 'build_croissant' ,
76
- help = 'Prepares a croissant dataset' ,
77
- )
78
- add_parser_arguments (parser )
79
- parser .set_defaults (
80
- subparser_fn = lambda args : prepare_croissant_builder (
81
- jsonld = args .jsonld ,
82
- data_dir = args .data_dir ,
83
- file_format = args .file_format ,
84
- record_sets = args .record_sets ,
85
- mapping = args .mapping ,
86
- download_dir = args .download_dir ,
87
- publish_dir = args .publish_dir ,
88
- skip_if_published = args .skip_if_published ,
89
- overwrite = args .overwrite ,
90
- )
91
- )
92
-
93
-
94
- def prepare_croissant_builder (
95
- jsonld : epath .PathLike ,
96
- data_dir : epath .PathLike ,
97
- file_format : str ,
98
- record_sets : Sequence [str ],
99
- mapping : str | None ,
100
- download_dir : epath .PathLike | None ,
101
- publish_dir : epath .PathLike | None ,
102
- skip_if_published : bool ,
103
- overwrite : bool ,
104
- ) -> None :
105
- # pyformat: disable
106
- """Creates a Croissant Builder and runs the preparation.
107
-
108
- Args:
109
- jsonld: The Croissant config file for the given dataset
45
+ Attributes:
46
+ jsonld: Path to the JSONLD file.
110
47
data_dir: Path where the converted dataset will be stored.
111
48
file_format: File format to convert the dataset to.
112
- record_sets: The `@id`s of the record sets to generate. Each record set will
49
+ record_sets: The names of the record sets to generate. Each record set will
113
50
correspond to a separate config. If not specified, it will use all the
114
- record sets
51
+ record sets.
115
52
mapping: Mapping filename->filepath as a Python dict[str, str] to handle
116
53
manual downloads. If `document.csv` is the FileObject and you downloaded
117
54
it to `~/Downloads/document.csv`, you can specify
118
- `mapping={"document.csv": "~/Downloads/document.csv"}`.,
55
+ `-- mapping=' {"document.csv": "~/Downloads/document.csv"}'`
119
56
download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
120
57
publish_dir: Where to optionally publish the dataset after it has been
121
58
generated successfully. Should be the root data dir under which datasets
@@ -124,29 +61,74 @@ def prepare_croissant_builder(
124
61
already published, then it will not be regenerated.
125
62
overwrite: Delete pre-existing dataset if it exists.
126
63
"""
127
- # pyformat: enable
128
- if not record_sets :
129
- record_sets = None
130
64
131
- if mapping :
65
+ jsonld : epath .PathLike
66
+ data_dir : epath .PathLike
67
+ # Need to override the default use of `Enum.name` for choice options.
68
+ file_format : str = simple_parsing .choice (
69
+ * (file_format .value for file_format in file_adapters .FileFormat ),
70
+ default = file_adapters .FileFormat .ARRAY_RECORD .value ,
71
+ )
72
+ # Need to manually parse comma-separated list of values, see:
73
+ # https://github.com/lebrice/SimpleParsing/issues/142.
74
+ record_sets : list [str ] = simple_parsing .field (
75
+ default_factory = list ,
76
+ type = lambda record_sets_str : record_sets_str .split (',' ),
77
+ nargs = '?' ,
78
+ )
79
+ mapping : str | None = None
80
+ download_dir : epath .PathLike | None = None
81
+ publish_dir : epath .PathLike | None = None
82
+ skip_if_published : bool = False
83
+ overwrite : bool = False
84
+
85
+
86
+ def register_subparser (parsers : argparse ._SubParsersAction ):
87
+ """Add subparser for `convert_format` command."""
88
+ orig_parser_class = parsers ._parser_class # pylint: disable=protected-access
89
+ try :
90
+ parsers ._parser_class = simple_parsing .ArgumentParser # pylint: disable=protected-access
91
+ parser = parsers .add_parser (
92
+ 'build_croissant' ,
93
+ help = 'Prepares a croissant dataset' ,
94
+ )
95
+ parser = typing .cast (simple_parsing .ArgumentParser , parser )
96
+ finally :
97
+ parsers ._parser_class = orig_parser_class # pylint: disable=protected-access
98
+ parser .add_arguments (CmdArgs , dest = 'args' )
99
+ parser .set_defaults (
100
+ subparser_fn = lambda args : prepare_croissant_builder (args .args )
101
+ )
102
+
103
+
104
+ def prepare_croissant_builder (args : CmdArgs ) -> None :
105
+ """Creates a Croissant Builder and runs the preparation.
106
+
107
+ Args:
108
+ args: CLI arguments.
109
+ """
110
+ if args .mapping :
132
111
try :
133
- mapping = json .loads (mapping )
112
+ mapping = json .loads (args . mapping )
134
113
except json .JSONDecodeError as e :
135
- raise ValueError (f'Error parsing mapping parameter: { mapping } ' ) from e
114
+ raise ValueError (
115
+ f'Error parsing mapping parameter: { args .mapping } '
116
+ ) from e
117
+ else :
118
+ mapping = None
136
119
137
120
builder = croissant_builder .CroissantBuilder (
138
- jsonld = jsonld ,
139
- record_set_ids = record_sets ,
140
- file_format = file_format ,
141
- data_dir = data_dir ,
121
+ jsonld = args . jsonld ,
122
+ record_set_ids = args . record_sets or None ,
123
+ file_format = args . file_format ,
124
+ data_dir = args . data_dir ,
142
125
mapping = mapping ,
143
126
)
144
127
cli_utils .download_and_prepare (
145
128
builder = builder ,
146
129
download_config = None ,
147
- download_dir = epath .Path (download_dir ) if download_dir else None ,
148
- publish_dir = epath .Path (publish_dir ) if publish_dir else None ,
149
- skip_if_published = skip_if_published ,
150
- freeze_files = freeze_files ,
151
- overwrite = overwrite ,
130
+ download_dir = epath .Path (args .download_dir ) if args .download_dir else None ,
131
+ publish_dir = epath .Path (args .publish_dir ) if args .publish_dir else None ,
132
+ skip_if_published = args .skip_if_published ,
133
+ overwrite = args .overwrite ,
152
134
)
0 commit comments