Skip to content

Commit 502c65f

Browse files
fineguy, The TensorFlow Datasets Authors
authored and committed
Fix parsing CLI arguments and ABSL flags.
PiperOrigin-RevId: 792554983
1 parent 3a34697 commit 502c65f

File tree

4 files changed

+56
-26
lines changed

4 files changed

+56
-26
lines changed

tensorflow_datasets/scripts/cli/cli_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,65 @@
1616
"""Utility functions for TFDS CLI."""
1717

1818
import argparse
19+
from collections.abc import Sequence
1920
import dataclasses
2021
import itertools
2122
import os
2223
import pathlib
2324

2425
from absl import logging
26+
from absl.flags import argparse_flags
2527
from etils import epath
28+
import simple_parsing
2629
from tensorflow_datasets.core import dataset_builder
2730
from tensorflow_datasets.core import download
2831
from tensorflow_datasets.core import file_adapters
2932
from tensorflow_datasets.core import naming
3033
from tensorflow_datasets.core.utils import file_utils
34+
from tensorflow_datasets.scripts.utils import flag_utils
35+
36+
37+
class ArgumentParser(
    argparse_flags.ArgumentParser, simple_parsing.ArgumentParser
):
  """An `ArgumentParser` that handles both `simple_parsing` and `absl` flags.

  This class is a workaround for the fact that `simple_parsing.ArgumentParser`
  does not natively handle `absl.flags`. Without this, `absl` flags are not
  correctly parsed, especially when they are mixed with positional arguments,
  leading to errors.

  The `absl.flags.argparse_flags.ArgumentParser` is designed to integrate `absl`
  flags into an `argparse` setup. It does this by dynamically adding all
  defined `absl` flags to the parser instance upon initialization.

  By inheriting from both, we get the features of both:
  - `simple_parsing.ArgumentParser`: Allows defining arguments from typed
    dataclasses.
  - `argparse_flags.ArgumentParser`: Adds support for `absl` flags.

  The Method Resolution Order (MRO), abbreviated to the relevant classes, is:
  `ArgumentParser` -> `argparse_flags.ArgumentParser` ->
  `simple_parsing.ArgumentParser` -> `argparse.ArgumentParser` -> `object`.

  This order is important. `argparse_flags.ArgumentParser` is first so that it
  can intercept arguments and handle `absl` flags before they are passed to
  `simple_parsing.ArgumentParser`.
  """

  def parse_known_args(
      self,
      args: Sequence[str] | None = None,
      namespace: argparse.Namespace | None = None,
      attempt_to_reorder: bool = False,
  ) -> tuple[argparse.Namespace, list[str]]:
    """Parses the known arguments after normalizing the flags.

    Args:
      args: Command-line arguments, without the program name. When `None`,
        `sys.argv[1:]` is used, matching the standard `argparse` default.
      namespace: Optional namespace to populate with the parsed values.
      attempt_to_reorder: Accepted for signature compatibility with
        `simple_parsing.ArgumentParser`; ignored.

    Returns:
      Whatever the parent `parse_known_args` returns: the populated namespace
      and the list of arguments that were not recognized.
    """
    # Function-scope import so this block does not depend on a file-level
    # `import sys`.
    import sys

    # `argparse_flags.ArgumentParser` does not support `attempt_to_reorder` that
    # is used by `simple_parsing.ArgumentParser`. Since we don't need it, we can
    # just ignore it.
    del attempt_to_reorder
    # Resolve the implicit default here, so that arguments taken from
    # `sys.argv` are normalized too instead of bypassing `normalize_flags`.
    if args is None:
      args = sys.argv[1:]
    if args:
      args = flag_utils.normalize_flags(args)
    return super().parse_known_args(args, namespace)
3178

3279

3380
@dataclasses.dataclass

tensorflow_datasets/scripts/cli/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616
"""CLI Fixtures."""
1717

18-
import argparse
1918
from unittest import mock
2019

2120
import pytest
21+
import simple_parsing
22+
from tensorflow_datasets.scripts.cli import cli_utils
2223

2324

2425
@pytest.fixture(scope='session', autouse=True)
@@ -29,7 +30,7 @@ def _mock_argparse_flags():
2930
# another test):
3031
# `flags.DEFINE_string('data_dir')` with `parser.add_argument('--data_dir')`
3132
# We patch argparse_flags during test, so absl flags are ignored.
32-
with mock.patch(
33-
'absl.flags.argparse_flags.ArgumentParser', argparse.ArgumentParser
33+
with mock.patch.object(
34+
cli_utils, 'ArgumentParser', simple_parsing.ArgumentParser
3435
):
3536
yield

tensorflow_datasets/scripts/cli/main.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,21 @@
2929
from absl import flags
3030
from absl import logging
3131

32-
import simple_parsing
3332
import tensorflow_datasets.public_api as tfds
3433

3534
# Import commands
3635
from tensorflow_datasets.scripts.cli import build
36+
from tensorflow_datasets.scripts.cli import cli_utils
3737
from tensorflow_datasets.scripts.cli import convert_format
3838
from tensorflow_datasets.scripts.cli import croissant
3939
from tensorflow_datasets.scripts.cli import new
40-
from tensorflow_datasets.scripts.utils import flag_utils
4140

4241
FLAGS = flags.FLAGS
4342

4443

4544
def _parse_flags(argv: List[str]) -> argparse.Namespace:
4645
"""Command-line flag parsing."""
47-
argv = flag_utils.normalize_flags(argv) # See b/174043007 for context.
48-
49-
parser = simple_parsing.ArgumentParser(
46+
parser = cli_utils.ArgumentParser(
5047
description='Tensorflow Datasets CLI tool',
5148
allow_abbrev=False,
5249
)
@@ -67,22 +64,7 @@ def _parse_flags(argv: List[str]) -> argparse.Namespace:
6764
new.register_subparser(subparser)
6865
convert_format.register_subparser(subparser)
6966
croissant.register_subparser(subparser)
70-
71-
namespace, remaining_argv = parser.parse_known_args(argv[1:])
72-
73-
# Manually parse absl flags from the remaining arguments.
74-
try:
75-
# FLAGS requires the program name as the first argument.
76-
positionals = FLAGS(argv[:1] + remaining_argv)
77-
except flags.Error as e:
78-
parser.error(str(e))
79-
80-
# There should be no positional arguments left, as they should have been
81-
# handled by the sub-commands.
82-
if len(positionals) > 1:
83-
parser.error(f"unrecognized arguments: {' '.join(positionals[1:])}")
84-
85-
return namespace
67+
return parser.parse_args(argv[1:])
8668

8769

8870
def main(args: argparse.Namespace) -> None:

tensorflow_datasets/scripts/utils/flag_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515

1616
"""Utility for handling flags."""
1717

18+
from collections.abc import Sequence
1819
import re
19-
from typing import List
2020

2121

22-
def normalize_flags(argv: List[str]) -> List[str]:
22+
def normalize_flags(argv: Sequence[str]) -> list[str]:
2323
"""Returns normalized explicit boolean flags for `absl.flags` compatibility.
2424
2525
Note: Boolean flags in `absl.flags` can be specified with --bool, --nobool,

0 commit comments

Comments
 (0)