Skip to content

Commit b7007e5

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Use simple_parsing for CLI flag parsing.
PiperOrigin-RevId: 791568554
1 parent 5b12259 commit b7007e5

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

tensorflow_datasets/scripts/cli/croissant.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,11 @@ def version(self) -> version_lib.Version:
125125

126126
def register_subparser(parsers: argparse._SubParsersAction):
127127
"""Add subparser for `convert_format` command."""
128-
orig_parser_class = parsers._parser_class # pylint: disable=protected-access
129-
try:
130-
parsers._parser_class = simple_parsing.ArgumentParser # pylint: disable=protected-access
131-
parser = parsers.add_parser(
132-
'build_croissant',
133-
help='Prepares a croissant dataset',
134-
)
135-
parser = typing.cast(simple_parsing.ArgumentParser, parser)
136-
finally:
137-
parsers._parser_class = orig_parser_class # pylint: disable=protected-access
128+
parser = parsers.add_parser(
129+
'build_croissant',
130+
help='Prepares a croissant dataset',
131+
)
132+
parser = typing.cast(simple_parsing.ArgumentParser, parser)
138133
parser.add_arguments(CmdArgs, dest='args')
139134
parser.set_defaults(
140135
subparser_fn=lambda args: prepare_croissant_builders(args.args)

tensorflow_datasets/scripts/cli/main.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from absl import app
2929
from absl import flags
3030
from absl import logging
31-
from absl.flags import argparse_flags
3231

32+
import simple_parsing
3333
import tensorflow_datasets.public_api as tfds
3434

3535
# Import commands
@@ -46,7 +46,7 @@ def _parse_flags(argv: List[str]) -> argparse.Namespace:
4646
"""Command lines flag parsing."""
4747
argv = flag_utils.normalize_flags(argv) # See b/174043007 for context.
4848

49-
parser = argparse_flags.ArgumentParser(
49+
parser = simple_parsing.ArgumentParser(
5050
description='Tensorflow Datasets CLI tool',
5151
allow_abbrev=False,
5252
)
@@ -67,7 +67,22 @@ def _parse_flags(argv: List[str]) -> argparse.Namespace:
6767
new.register_subparser(subparser)
6868
convert_format.register_subparser(subparser)
6969
croissant.register_subparser(subparser)
70-
return parser.parse_args(argv[1:])
70+
71+
namespace, remaining_argv = parser.parse_known_args(argv[1:])
72+
73+
# Manually parse absl flags from the remaining arguments.
74+
try:
75+
# FLAGS requires the program name as the first argument.
76+
positionals = FLAGS(argv[:1] + remaining_argv)
77+
except flags.Error as e:
78+
parser.error(str(e))
79+
80+
# There should be no positional arguments left, as they should have been
81+
# handled by the sub-commands.
82+
if len(positionals) > 1:
83+
parser.error(f"unrecognized arguments: {' '.join(positionals[1:])}")
84+
85+
return namespace
7186

7287

7388
def main(args: argparse.Namespace) -> None:

0 commit comments

Comments
 (0)