|
26 | 26 | ```
|
27 | 27 | """
|
28 | 28 |
|
| 29 | +import dataclasses |
| 30 | + |
29 | 31 | from absl import app
|
30 |
| -from absl import flags |
| 32 | +from etils import eapp |
| 33 | +from etils import epath |
| 34 | +import simple_parsing |
31 | 35 | from tensorflow_datasets.core import file_adapters
|
32 | 36 | from tensorflow_datasets.scripts.cli import croissant
|
33 | 37 |
|
34 | 38 |
|
35 |
| -_JSONLD = flags.DEFINE_string( |
36 |
| - name='jsonld', default=None, help='Path to the JSONLD file.', required=True |
37 |
| -) |
38 |
| -_DATA_DIR = flags.DEFINE_string( |
39 |
| - name='data_dir', |
40 |
| - default=None, |
41 |
| - help='Path where the converted dataset will be stored.', |
42 |
| - required=True, |
43 |
| -) |
44 |
| -_FILE_FORMAT = flags.DEFINE_enum_class( |
45 |
| - name='file_format', |
46 |
| - default=file_adapters.FileFormat.ARRAY_RECORD, |
47 |
| - enum_class=file_adapters.FileFormat, |
48 |
| - help='File format to convert the dataset to.', |
49 |
| -) |
50 |
| -_RECORD_SETS = flags.DEFINE_list( |
51 |
| - name='record_sets', |
52 |
| - default=[], |
53 |
| - help=( |
54 |
| - 'The names of the record sets to generate. Each record set will' |
55 |
| - ' correspond to a separate config. If not specified, it will use all' |
56 |
| - ' the record sets.' |
57 |
| - ), |
58 |
| -) |
59 |
| -_MAPPING = flags.DEFINE_string( |
60 |
| - name='mapping', |
61 |
| - default=None, |
62 |
| - help=( |
63 |
| - 'Mapping filename->filepath as a Python dict[str, str] to handle' |
64 |
| - ' manual downloads. If `document.csv` is the FileObject and you' |
65 |
| - ' downloaded it to `~/Downloads/document.csv`, you can' |
66 |
| - ' specify`--mapping=\'{"document.csv": "~/Downloads/document.csv"}\'' |
67 |
| - ), |
68 |
| -) |
69 |
| -_DOWNLOAD_DIR = flags.DEFINE_string( |
70 |
| - name='download_dir', |
71 |
| - default=None, |
72 |
| - help='Where to place downloads. Default to `<data_dir>/downloads/`.', |
73 |
| -) |
74 |
| -_PUBLISH_DIR = flags.DEFINE_string( |
75 |
| - name='publish_dir', |
76 |
| - default=None, |
77 |
| - help=( |
78 |
| - 'Where to optionally publish the dataset after it has been generated ' |
79 |
| - 'successfully. Should be the root data dir under which datasets are ' |
80 |
| - 'stored. If unspecified, dataset will not be published.' |
81 |
| - ), |
82 |
| -) |
83 |
| -_SKIP_IF_PUBLISHED = flags.DEFINE_bool( |
84 |
| - name='skip_if_published', |
85 |
| - default=False, |
86 |
| - help=( |
87 |
| - 'If the dataset with the same version and config is already published, ' |
88 |
| - 'then it will not be regenerated.' |
89 |
| - ), |
90 |
| -) |
91 |
| -_OVERWRITE = flags.DEFINE_bool( |
92 |
| - name='overwrite', |
93 |
| - default=False, |
94 |
| - help='Delete pre-existing dataset if it exists.', |
95 |
| -) |
@dataclasses.dataclass
class CmdArgs:
  """Command-line arguments for converting a Croissant dataset to TFDS.

  Attributes:
    jsonld: Path to the JSONLD file.
    data_dir: Path where the converted dataset will be stored.
    file_format: File format to convert the dataset to.
    record_sets: The names of the record sets to generate. Each record set will
      correspond to a separate config. If not specified, it will use all the
      record sets.
    mapping: Mapping filename->filepath as a Python dict[str, str] to handle
      manual downloads. If `document.csv` is the FileObject and you downloaded
      it to `~/Downloads/document.csv`, you can specify
      `--mapping='{"document.csv": "~/Downloads/document.csv"}'`
    download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
    publish_dir: Where to optionally publish the dataset after it has been
      generated successfully. Should be the root data dir under which datasets
      are stored. If unspecified, dataset will not be published.
    skip_if_published: If the dataset with the same version and config is
      already published, then it will not be regenerated.
    overwrite: Delete pre-existing dataset if it exists.
  """

  jsonld: epath.PathLike
  data_dir: epath.PathLike
  # Need to override the default use of `Enum.name` for choice options; the
  # CLI accepts the enum *values*, so spell them out explicitly.
  file_format: str = simple_parsing.choice(
      *[fmt.value for fmt in file_adapters.FileFormat],
      default=file_adapters.FileFormat.ARRAY_RECORD.value,
  )
  # Need to manually parse comma-separated list of values, see:
  # https://github.com/lebrice/SimpleParsing/issues/142.
  record_sets: list[str] = simple_parsing.field(
      default_factory=list,
      type=lambda record_sets_str: record_sets_str.split(','),
      nargs='?',
  )
  mapping: str | None = None
  download_dir: epath.PathLike | None = None
  publish_dir: epath.PathLike | None = None
  skip_if_published: bool = False
  overwrite: bool = False


# Turns the dataclass into an absl-compatible `flags_parser` callable.
parse_flags = eapp.make_flags_parser(CmdArgs)
96 | 84 |
|
97 | 85 |
|
98 |
def main(args: CmdArgs):
  """Converts the Croissant dataset described by `args` into TFDS format.

  Args:
    args: Parsed command-line arguments; see `CmdArgs` for the full contract.
  """
  # The `CmdArgs` field names deliberately mirror the keyword parameters of
  # `prepare_croissant_builder`, so the dataclass can be forwarded wholesale.
  croissant.prepare_croissant_builder(**dataclasses.asdict(args))
|
110 | 98 |
|
111 | 99 |
|
if __name__ == '__main__':
  # Hand absl our `simple_parsing`-based parser so `main` receives a fully
  # populated `CmdArgs` instead of the raw argv list.
  app.run(main, flags_parser=parse_flags)
0 commit comments