Commit a586b8e

fineguy authored and The TensorFlow Datasets Authors committed
Use dataclasses for CLI arguments in scripts/cli/croissant.py.
PiperOrigin-RevId: 633542690
1 parent 90a6846 commit a586b8e

3 files changed (+76, -156 lines)


setup.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@
     'psutil',
     'pyarrow',
     'requests>=2.19.0',
+    'simple_parsing',
     'tensorflow-metadata',
     'termcolor',
     'toml',
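
Note: the new `simple_parsing` dependency is what lets a dataclass replace a block of `parser.add_argument` calls. A minimal standalone sketch (not project code; the `Args` class and flag values below are illustrative) of how the library derives flags from dataclass fields:

```python
# Standalone sketch: simple_parsing derives CLI flags from dataclass fields.
import dataclasses

import simple_parsing


@dataclasses.dataclass
class Args:
  jsonld: str  # No default, so --jsonld becomes a required flag.
  overwrite: bool = False  # Defaulted fields become optional flags.


parser = simple_parsing.ArgumentParser()
parser.add_arguments(Args, dest='args')
ns = parser.parse_args(['--jsonld', '/tmp/croissant.json'])
print(ns.args)  # Args(jsonld='/tmp/croissant.json', overwrite=False)
```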

tensorflow_datasets/scripts/cli/croissant.py

Lines changed: 73 additions & 91 deletions
@@ -21,101 +21,38 @@
   --jsonld=/tmp/croissant.json \
   --data_dir=/tmp/foo \
   --file_format=array_record \
-  --record_sets=record1 --record_sets=record2 \
+  --record_sets=record1,record2 \
   --mapping='{"document.csv": "~/Downloads/document.csv"}"'
 ```
 """
 
 import argparse
-from collections.abc import Sequence
+import dataclasses
 import json
+import typing
 
 from etils import epath
+import simple_parsing
+from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.core.dataset_builders import croissant_builder
 from tensorflow_datasets.scripts.cli import cli_utils
 
 
-def add_parser_arguments(parser: argparse.ArgumentParser):
-  """Add arguments for `build_croissant` subparser."""
-  parser.add_argument(
-      '--jsonld',
-      type=str,
-      help='The Croissant config file for the given dataset.',
-      required=True,
-  )
-  parser.add_argument(
-      '--record_sets',
-      nargs='*',
-      help=(
-          'The names of the record sets to generate. Each record set will'
-          ' correspond to a separate config. If not specified, it will use all'
-          ' the record sets'
-      ),
-  )
-  parser.add_argument(
-      '--mapping',
-      type=str,
-      help=(
-          'Mapping filename->filepath as a Python dict[str, str] to handle'
-          ' manual downloads. If `document.csv` is the FileObject and you'
-          ' downloaded it to `~/Downloads/document.csv`, you can'
-          ' specify`--mapping=\'{"document.csv": "~/Downloads/document.csv"}\''
-      ),
-  )
-
-  cli_utils.add_debug_argument_group(parser)
-  cli_utils.add_path_argument_group(parser)
-  cli_utils.add_generation_argument_group(parser)
-  cli_utils.add_publish_argument_group(parser)
+@dataclasses.dataclass
+class CmdArgs:
+  """CLI arguments for preparing a Croissant dataset.
 
-
-def register_subparser(parsers: argparse._SubParsersAction):
-  """Add subparser for `convert_format` command."""
-  parser = parsers.add_parser(
-      'build_croissant',
-      help='Prepares a croissant dataset',
-  )
-  add_parser_arguments(parser)
-  parser.set_defaults(
-      subparser_fn=lambda args: prepare_croissant_builder(
-          jsonld=args.jsonld,
-          data_dir=args.data_dir,
-          file_format=args.file_format,
-          record_sets=args.record_sets,
-          mapping=args.mapping,
-          download_dir=args.download_dir,
-          publish_dir=args.publish_dir,
-          skip_if_published=args.skip_if_published,
-          overwrite=args.overwrite,
-      )
-  )
-
-
-def prepare_croissant_builder(
-    jsonld: epath.PathLike,
-    data_dir: epath.PathLike,
-    file_format: str,
-    record_sets: Sequence[str],
-    mapping: str | None,
-    download_dir: epath.PathLike | None,
-    publish_dir: epath.PathLike | None,
-    skip_if_published: bool,
-    overwrite: bool,
-) -> None:
-  # pyformat: disable
-  """Creates a Croissant Builder and runs the preparation.
-
-  Args:
-    jsonld: The Croissant config file for the given dataset
+  Attributes:
+    jsonld: Path to the JSONLD file.
     data_dir: Path where the converted dataset will be stored.
     file_format: File format to convert the dataset to.
-    record_sets: The `@id`s of the record sets to generate. Each record set will
+    record_sets: The names of the record sets to generate. Each record set will
      correspond to a separate config. If not specified, it will use all the
-      record sets
+      record sets.
     mapping: Mapping filename->filepath as a Python dict[str, str] to handle
      manual downloads. If `document.csv` is the FileObject and you downloaded
      it to `~/Downloads/document.csv`, you can specify
-      `mapping={"document.csv": "~/Downloads/document.csv"}`.,
+      `--mapping='{"document.csv": "~/Downloads/document.csv"}'`
    download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
    publish_dir: Where to optionally publish the dataset after it has been
      generated successfully. Should be the root data dir under which datasets
@@ -124,29 +61,74 @@ def prepare_croissant_builder(
      already published, then it will not be regenerated.
    overwrite: Delete pre-existing dataset if it exists.
   """
-  # pyformat: enable
-  if not record_sets:
-    record_sets = None
 
-  if mapping:
+  jsonld: epath.PathLike
+  data_dir: epath.PathLike
+  # Need to override the default use of `Enum.name` for choice options.
+  file_format: str = simple_parsing.choice(
+      *(file_format.value for file_format in file_adapters.FileFormat),
+      default=file_adapters.FileFormat.ARRAY_RECORD.value,
+  )
+  # Need to manually parse comma-separated list of values, see:
+  # https://github.com/lebrice/SimpleParsing/issues/142.
+  record_sets: list[str] = simple_parsing.field(
+      default_factory=list,
+      type=lambda record_sets_str: record_sets_str.split(','),
+      nargs='?',
+  )
+  mapping: str | None = None
+  download_dir: epath.PathLike | None = None
+  publish_dir: epath.PathLike | None = None
+  skip_if_published: bool = False
+  overwrite: bool = False
+
+
+def register_subparser(parsers: argparse._SubParsersAction):
+  """Add subparser for `convert_format` command."""
+  orig_parser_class = parsers._parser_class  # pylint: disable=protected-access
+  try:
+    parsers._parser_class = simple_parsing.ArgumentParser  # pylint: disable=protected-access
+    parser = parsers.add_parser(
+        'build_croissant',
+        help='Prepares a croissant dataset',
+    )
+    parser = typing.cast(simple_parsing.ArgumentParser, parser)
+  finally:
+    parsers._parser_class = orig_parser_class  # pylint: disable=protected-access
+  parser.add_arguments(CmdArgs, dest='args')
+  parser.set_defaults(
+      subparser_fn=lambda args: prepare_croissant_builder(args.args)
+  )
+
+
+def prepare_croissant_builder(args: CmdArgs) -> None:
+  """Creates a Croissant Builder and runs the preparation.
+
+  Args:
+    args: CLI arguments.
+  """
+  if args.mapping:
     try:
-      mapping = json.loads(mapping)
+      mapping = json.loads(args.mapping)
    except json.JSONDecodeError as e:
-      raise ValueError(f'Error parsing mapping parameter: {mapping}') from e
+      raise ValueError(
+          f'Error parsing mapping parameter: {args.mapping}'
+      ) from e
+  else:
+    mapping = None
 
   builder = croissant_builder.CroissantBuilder(
-      jsonld=jsonld,
-      record_set_ids=record_sets,
-      file_format=file_format,
-      data_dir=data_dir,
+      jsonld=args.jsonld,
+      record_set_ids=args.record_sets or None,
+      file_format=args.file_format,
+      data_dir=args.data_dir,
      mapping=mapping,
   )
   cli_utils.download_and_prepare(
      builder=builder,
      download_config=None,
-      download_dir=epath.Path(download_dir) if download_dir else None,
-      publish_dir=epath.Path(publish_dir) if publish_dir else None,
-      skip_if_published=skip_if_published,
-      freeze_files=freeze_files,
-      overwrite=overwrite,
+      download_dir=epath.Path(args.download_dir) if args.download_dir else None,
+      publish_dir=epath.Path(args.publish_dir) if args.publish_dir else None,
+      skip_if_published=args.skip_if_published,
+      overwrite=args.overwrite,
   )
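
Note: the interesting part of the new `register_subparser` is the temporary swap of `parsers._parser_class`: argparse builds each subcommand parser from that class, and `add_arguments` only exists on `simple_parsing.ArgumentParser`. A minimal standalone sketch of the same trick (the `DemoArgs` dataclass and the `demo` subcommand are made-up names, not project code):

```python
# Standalone sketch of the _parser_class swap with a toy dataclass: argparse
# instantiates _parser_class inside add_parser(), so swapping it makes the
# new subcommand a simple_parsing.ArgumentParser that accepts add_arguments().
import argparse
import dataclasses

import simple_parsing


@dataclasses.dataclass
class DemoArgs:
  jsonld: str
  overwrite: bool = False


root = argparse.ArgumentParser()
subparsers = root.add_subparsers()

orig_parser_class = subparsers._parser_class  # pylint: disable=protected-access
try:
  subparsers._parser_class = simple_parsing.ArgumentParser
  demo_parser = subparsers.add_parser('demo')
finally:
  # Restore so later add_parser() calls build plain argparse parsers again.
  subparsers._parser_class = orig_parser_class
demo_parser.add_arguments(DemoArgs, dest='args')

ns = root.parse_args(['demo', '--jsonld', '/tmp/croissant.json'])
print(ns.args)  # DemoArgs(jsonld='/tmp/croissant.json', overwrite=False)
```

Restoring the original class in `finally`, as the diff does, presumably keeps the swap from leaking into subcommands registered afterwards.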

tensorflow_datasets/scripts/prepare_croissant.py

Lines changed: 2 additions & 65 deletions
@@ -26,76 +26,13 @@
 ```
 """
 
-import dataclasses
-
 from absl import app
 from etils import eapp
-from etils import epath
-import simple_parsing
-from tensorflow_datasets.core import file_adapters
 from tensorflow_datasets.scripts.cli import croissant
 
 
-@dataclasses.dataclass
-class CmdArgs:
-  """CLI arguments for preparing a Croissant dataset.
-
-  Attributes:
-    jsonld: Path to the JSONLD file.
-    data_dir: Path where the converted dataset will be stored.
-    file_format: File format to convert the dataset to.
-    record_sets: The names of the record sets to generate. Each record set will
-      correspond to a separate config. If not specified, it will use all the
-      record sets.
-    mapping: Mapping filename->filepath as a Python dict[str, str] to handle
-      manual downloads. If `document.csv` is the FileObject and you downloaded
-      it to `~/Downloads/document.csv`, you can specify
-      `--mapping='{"document.csv": "~/Downloads/document.csv"}'`
-    download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
-    publish_dir: Where to optionally publish the dataset after it has been
-      generated successfully. Should be the root data dir under which datasets
-      are stored. If unspecified, dataset will not be published.
-    skip_if_published: If the dataset with the same version and config is
-      already published, then it will not be regenerated.
-    overwrite: Delete pre-existing dataset if it exists.
-  """
-
-  jsonld: epath.PathLike
-  data_dir: epath.PathLike
-  # Need to override the default use of `Enum.name` for choice options.
-  file_format: str = simple_parsing.choice(
-      *(file_format.value for file_format in file_adapters.FileFormat),
-      default=file_adapters.FileFormat.ARRAY_RECORD.value,
-  )
-  # Need to manually parse comma-separated list of values, see:
-  # https://github.com/lebrice/SimpleParsing/issues/142.
-  record_sets: list[str] = simple_parsing.field(
-      default_factory=list,
-      type=lambda record_sets_str: record_sets_str.split(','),
-      nargs='?',
-  )
-  mapping: str | None = None
-  download_dir: epath.PathLike | None = None
-  publish_dir: epath.PathLike | None = None
-  skip_if_published: bool = False
-  overwrite: bool = False
-
-parse_flags = eapp.make_flags_parser(CmdArgs)
-
-
-def main(args: CmdArgs):
-  croissant.prepare_croissant_builder(
-      jsonld=args.jsonld,
-      data_dir=args.data_dir,
-      file_format=args.file_format,
-      record_sets=args.record_sets,
-      mapping=args.mapping,
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
-  )
+parse_flags = eapp.make_flags_parser(croissant.CmdArgs)
 
 
 if __name__ == '__main__':
-  app.run(main, flags_parser=parse_flags)
+  app.run(croissant.prepare_croissant_builder, flags_parser=parse_flags)
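
Note: after this change the standalone script simply reuses `croissant.CmdArgs` and passes `croissant.prepare_croissant_builder` straight to `app.run`; `eapp.make_flags_parser` turns the dataclass into an absl flags parser so the callback receives a populated instance. A minimal standalone sketch of that wiring with a toy dataclass (`Args` and the printed message are illustrative, not project code):

```python
# Standalone sketch of the absl + etils.eapp wiring with a toy dataclass:
# make_flags_parser() parses CLI flags into the dataclass, and app.run()
# hands the populated instance straight to the callback.
import dataclasses

from absl import app
from etils import eapp


@dataclasses.dataclass
class Args:
  jsonld: str
  overwrite: bool = False


def main(args: Args) -> None:
  print(f'jsonld={args.jsonld} overwrite={args.overwrite}')


if __name__ == '__main__':
  app.run(main, flags_parser=eapp.make_flags_parser(Args))
```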
