Skip to content

Commit df931fb

Browse files
fineguy (The TensorFlow Datasets Authors)
authored and committed
Use etils.eapp for prepare_croissant.py
PiperOrigin-RevId: 633316246
1 parent 548f2c4 commit df931fb

File tree

1 file changed

+61
-73
lines changed

1 file changed

+61
-73
lines changed

tensorflow_datasets/scripts/prepare_croissant.py

Lines changed: 61 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -26,88 +26,76 @@
2626
```
2727
"""
2828

29+
import dataclasses
30+
2931
from absl import app
30-
from absl import flags
32+
from etils import eapp
33+
from etils import epath
34+
import simple_parsing
3135
from tensorflow_datasets.core import file_adapters
3236
from tensorflow_datasets.scripts.cli import croissant
3337

3438

35-
_JSONLD = flags.DEFINE_string(
36-
name='jsonld', default=None, help='Path to the JSONLD file.', required=True
37-
)
38-
_DATA_DIR = flags.DEFINE_string(
39-
name='data_dir',
40-
default=None,
41-
help='Path where the converted dataset will be stored.',
42-
required=True,
43-
)
44-
_FILE_FORMAT = flags.DEFINE_enum_class(
45-
name='file_format',
46-
default=file_adapters.FileFormat.ARRAY_RECORD,
47-
enum_class=file_adapters.FileFormat,
48-
help='File format to convert the dataset to.',
49-
)
50-
_RECORD_SETS = flags.DEFINE_list(
51-
name='record_sets',
52-
default=[],
53-
help=(
54-
'The names of the record sets to generate. Each record set will'
55-
' correspond to a separate config. If not specified, it will use all'
56-
' the record sets.'
57-
),
58-
)
59-
_MAPPING = flags.DEFINE_string(
60-
name='mapping',
61-
default=None,
62-
help=(
63-
'Mapping filename->filepath as a Python dict[str, str] to handle'
64-
' manual downloads. If `document.csv` is the FileObject and you'
65-
' downloaded it to `~/Downloads/document.csv`, you can'
66-
' specify`--mapping=\'{"document.csv": "~/Downloads/document.csv"}\''
67-
),
68-
)
69-
_DOWNLOAD_DIR = flags.DEFINE_string(
70-
name='download_dir',
71-
default=None,
72-
help='Where to place downloads. Default to `<data_dir>/downloads/`.',
73-
)
74-
_PUBLISH_DIR = flags.DEFINE_string(
75-
name='publish_dir',
76-
default=None,
77-
help=(
78-
'Where to optionally publish the dataset after it has been generated '
79-
'successfully. Should be the root data dir under which datasets are '
80-
'stored. If unspecified, dataset will not be published.'
81-
),
82-
)
83-
_SKIP_IF_PUBLISHED = flags.DEFINE_bool(
84-
name='skip_if_published',
85-
default=False,
86-
help=(
87-
'If the dataset with the same version and config is already published, '
88-
'then it will not be regenerated.'
89-
),
90-
)
91-
_OVERWRITE = flags.DEFINE_bool(
92-
name='overwrite',
93-
default=False,
94-
help='Delete pre-existing dataset if it exists.',
95-
)
39+
@dataclasses.dataclass
class CmdArgs:
  """CLI arguments for preparing a Croissant dataset.

  Attributes:
    jsonld: Path to the JSONLD file.
    data_dir: Path where the converted dataset will be stored.
    file_format: File format to convert the dataset to.
    record_sets: The names of the record sets to generate. Each record set will
      correspond to a separate config. If not specified, it will use all the
      record sets.
    mapping: Mapping filename->filepath as a Python dict[str, str] to handle
      manual downloads. If `document.csv` is the FileObject and you downloaded
      it to `~/Downloads/document.csv`, you can specify
      `--mapping='{"document.csv": "~/Downloads/document.csv"}'`
    download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
    publish_dir: Where to optionally publish the dataset after it has been
      generated successfully. Should be the root data dir under which datasets
      are stored. If unspecified, dataset will not be published.
    skip_if_published: If the dataset with the same version and config is
      already published, then it will not be regenerated.
    overwrite: Delete pre-existing dataset if it exists.
  """

  # NOTE: simple_parsing derives the CLI flags from these field names, types
  # and defaults, so renaming/reordering fields changes the command line.
  jsonld: epath.PathLike
  data_dir: epath.PathLike
  # Need to override the default use of `Enum.name` for choice options.
  file_format: str = simple_parsing.choice(
      *(file_format.value for file_format in file_adapters.FileFormat),
      default=file_adapters.FileFormat.ARRAY_RECORD.value,
  )
  # Need to manually parse comma-separated list of values, see:
  # https://github.com/lebrice/SimpleParsing/issues/142.
  record_sets: list[str] = simple_parsing.field(
      default_factory=list,
      type=lambda record_sets_str: record_sets_str.split(','),
      nargs='?',
  )
  # Serialized dict[str, str]; presumably decoded downstream by
  # `prepare_croissant_builder` — not parsed here.
  mapping: str | None = None
  download_dir: epath.PathLike | None = None
  publish_dir: epath.PathLike | None = None
  skip_if_published: bool = False
  overwrite: bool = False


# absl-compatible flags parser built from the dataclass; turns `sys.argv`
# into a `CmdArgs` instance for `app.run`.
parse_flags = eapp.make_flags_parser(CmdArgs)
9684

9785

98-
def main(args: CmdArgs):
  """Prepares the Croissant dataset described by the parsed CLI arguments.

  Args:
    args: Parsed command-line arguments (see `CmdArgs` for flag semantics).
  """
  # Forward each CLI argument one-to-one to the builder preparation helper.
  builder_kwargs = dict(
      jsonld=args.jsonld,
      data_dir=args.data_dir,
      file_format=args.file_format,
      record_sets=args.record_sets,
      mapping=args.mapping,
      download_dir=args.download_dir,
      publish_dir=args.publish_dir,
      skip_if_published=args.skip_if_published,
      overwrite=args.overwrite,
  )
  croissant.prepare_croissant_builder(**builder_kwargs)

11199

112100
if __name__ == '__main__':
  # `parse_flags` (an `eapp.make_flags_parser` result) replaces absl's default
  # flag parsing so `main` receives a populated `CmdArgs` instead of argv.
  app.run(main, flags_parser=parse_flags)

0 commit comments

Comments
 (0)