Skip to content

Commit 731969b

Browse files
authored
feat: ✨ automatically replace invalid enum expressions with corresponding valid expression & import (#196)
1 parent c0f24f9 commit 731969b

File tree

35 files changed

+451
-16
lines changed

35 files changed

+451
-16
lines changed

pybind11_stubgen/parser/errors.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from pybind11_stubgen.structs import Identifier, QualifiedName
3+
from pybind11_stubgen.structs import Identifier, Import, QualifiedName, Value
44

55

66
class ParserError(Exception):
@@ -33,3 +33,23 @@ def __init__(self, name: QualifiedName):
3333

3434
def __str__(self):
3535
return f"Can't find/import '{self.name}'"
36+
37+
38+
class AmbiguousEnumError(InvalidExpressionError):
39+
def __init__(self, repr_: str, *values_and_imports: tuple[Value, Import]):
40+
super().__init__(repr_)
41+
self.values_and_imports = values_and_imports
42+
43+
if len(self.values_and_imports) < 2:
44+
raise ValueError(
45+
"Expected at least 2 values_and_imports, got "
46+
f"{len(self.values_and_imports)}"
47+
)
48+
49+
def __str__(self) -> str:
50+
origins = sorted(import_.origin for _, import_ in self.values_and_imports)
51+
return (
52+
f"Enum member '{self.expression}' could not be resolved; multiple "
53+
"matching definitions found in: "
54+
+ ", ".join(f"'{origin}'" for origin in origins)
55+
)

pybind11_stubgen/parser/mixins/error_handlers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def handle_method(self, path: QualifiedName, class_: type) -> list[Method]:
5656
with self.__new_layer(path):
5757
return super().handle_method(path, class_)
5858

59+
def finalize(self) -> None:
60+
with self.__new_layer(QualifiedName.from_str("finalize")):
61+
return super().finalize()
62+
5963
@property
6064
def current_path(self) -> QualifiedName:
6165
assert len(self.stack) != 0

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 137 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import re
77
import sys
88
import types
9+
from collections import defaultdict
910
from logging import getLogger
1011
from typing import Any, Sequence
1112

1213
from pybind11_stubgen.parser.errors import (
14+
AmbiguousEnumError,
1315
InvalidExpressionError,
1416
NameResolutionError,
1517
ParserError,
@@ -999,13 +1001,56 @@ def parse_value_str(self, value: str) -> Value | InvalidExpression:
9991001

10001002

10011003
class RewritePybind11EnumValueRepr(IParser):
1004+
"""Reformat pybind11-generated invalid enum value reprs.
1005+
1006+
For example, pybind11 may generate a `__doc__` like this:
1007+
>>> "set_color(self, color: <ConsoleForegroundColor.Blue: 34>) -> None:\n"
1008+
1009+
Which is invalid python syntax. This parser will rewrite the generated stub to:
1010+
>>> from demo._bindings.enum import ConsoleForegroundColor
1011+
>>> def set_color(self, color: ConsoleForegroundColor.Blue) -> None:
1012+
>>> ...
1013+
1014+
Since `pybind11_stubgen` encounters the values corresponding to these reprs as it
1015+
parses the modules, it can automatically replace these invalid expressions with the
1016+
corresponding `Value` and `Import` as it encounters them. There are 3 cases for an
1017+
`Argument` whose `default` is an enum `InvalidExpression`:
1018+
1019+
1. The `InvalidExpression` repr corresponds to exactly one enum field definition.
1020+
The `InvalidExpression` is simply replaced by the corresponding `Value`.
1021+
2. The `InvalidExpression` repr corresponds to multiple enum field definitions. An
1022+
`AmbiguousEnumError` is reported.
1023+
3. The `InvalidExpression` repr corresponds to no enum field definitions. An
1024+
`InvalidExpressionError` is reported.
1025+
1026+
Attributes:
1027+
_pybind11_enum_pattern: Pattern matching pybind11 enum field reprs.
1028+
_unknown_enum_classes: Set of the str names of enum classes whose reprs were not
1029+
seen.
1030+
_invalid_default_arguments: Per module invalid arguments. Used to know which
1031+
enum imports to add to the current module.
1032+
_repr_to_value_and_import: Saved safe print values of enum field reprs and the
1033+
import to add to a module when when that repr is seen.
1034+
_repr_to_invalid_default_arguments: Groups of arguments whose default values are
1035+
`InvalidExpression`s. This is only used until the first time each repr is
1036+
seen. Left over groups will raise an error, which may be fixed using
1037+
`--enum-class-locations` or suppressed using `--ignore-invalid-expressions`.
1038+
_invalid_default_argument_to_module: Maps individual invalid default arguments
1039+
to the module containing them. Used to know which enum imports to add to
1040+
which module.
1041+
"""
1042+
10021043
_pybind11_enum_pattern = re.compile(r"<(?P<enum>\w+(\.\w+)+): (?P<value>-?\d+)>")
1003-
# _pybind11_enum_pattern = re.compile(r"<(?P<enum>\w+(\.\w+)+): (?P<value>\d+)>")
10041044
_unknown_enum_classes: set[str] = set()
1045+
_invalid_default_arguments: list[Argument] = []
1046+
_repr_to_value_and_import: dict[str, set[tuple[Value, Import]]] = defaultdict(set)
1047+
_repr_to_invalid_default_arguments: dict[str, set[Argument]] = defaultdict(set)
1048+
_invalid_default_argument_to_module: dict[Argument, Module] = {}
10051049

10061050
def __init__(self):
10071051
super().__init__()
10081052
self._pybind11_enum_locations: dict[re.Pattern, str] = {}
1053+
self._is_finalizing = False
10091054

10101055
def set_pybind11_enum_locations(self, locations: dict[re.Pattern, str]):
10111056
self._pybind11_enum_locations = locations
@@ -1024,17 +1069,104 @@ def parse_value_str(self, value: str) -> Value | InvalidExpression:
10241069
return Value(repr=f"{enum_class.name}.{entry}", is_print_safe=True)
10251070
return super().parse_value_str(value)
10261071

1072+
def handle_module(
1073+
self, path: QualifiedName, module: types.ModuleType
1074+
) -> Module | None:
1075+
# we may be handling a module within a module, so save the parent's invalid
1076+
# arguments on the stack as we handle this module
1077+
parent_module_invalid_arguments = self._invalid_default_arguments
1078+
self._invalid_default_arguments = []
1079+
result = super().handle_module(path, module)
1080+
1081+
if result is None:
1082+
self._invalid_default_arguments = parent_module_invalid_arguments
1083+
return None
1084+
1085+
# register each argument to the current module
1086+
while self._invalid_default_arguments:
1087+
arg = self._invalid_default_arguments.pop()
1088+
assert isinstance(arg.default, InvalidExpression)
1089+
repr_ = arg.default.text
1090+
self._repr_to_invalid_default_arguments[repr_].add(arg)
1091+
self._invalid_default_argument_to_module[arg] = result
1092+
1093+
self._invalid_default_arguments = parent_module_invalid_arguments
1094+
return result
1095+
1096+
def handle_function(self, path: QualifiedName, func: Any) -> list[Function]:
1097+
result = super().handle_function(path, func)
1098+
1099+
for f in result:
1100+
for arg in f.args:
1101+
if isinstance(arg.default, InvalidExpression):
1102+
# this argument will be registered to the current module
1103+
self._invalid_default_arguments.append(arg)
1104+
1105+
return result
1106+
1107+
def handle_attribute(self, path: QualifiedName, attr: Any) -> Attribute | None:
1108+
module = inspect.getmodule(attr)
1109+
repr_ = repr(attr)
1110+
1111+
if module is not None:
1112+
module_path = QualifiedName.from_str(module.__name__)
1113+
is_source_module = path[: len(module_path)] == module_path
1114+
is_alias = ( # could be an `.export_values()` alias, which we want to avoid
1115+
is_source_module
1116+
and not inspect.isclass(getattr(module, path[len(module_path)]))
1117+
)
1118+
1119+
if not is_alias and is_source_module:
1120+
# register one of the possible sources of this repr
1121+
self._repr_to_value_and_import[repr_].add(
1122+
(
1123+
Value(repr=".".join(path), is_print_safe=True),
1124+
Import(name=None, origin=module_path),
1125+
)
1126+
)
1127+
1128+
return super().handle_attribute(path, attr)
1129+
10271130
def report_error(self, error: ParserError) -> None:
1028-
if isinstance(error, InvalidExpressionError):
1131+
# defer reporting invalid enum expressions until finalization
1132+
if not self._is_finalizing and isinstance(error, InvalidExpressionError):
10291133
match = self._pybind11_enum_pattern.match(error.expression)
10301134
if match is not None:
1135+
return
1136+
super().report_error(error)
1137+
1138+
def finalize(self) -> None:
1139+
self._is_finalizing = True
1140+
for repr_, args in self._repr_to_invalid_default_arguments.items():
1141+
match = self._pybind11_enum_pattern.match(repr_)
1142+
if match is None:
1143+
pass
1144+
elif repr_ not in self._repr_to_value_and_import:
10311145
enum_qual_name = match.group("enum")
1032-
enum_class_str, entry = enum_qual_name.rsplit(".", maxsplit=1)
1146+
enum_class_str, _ = enum_qual_name.rsplit(".", maxsplit=1)
10331147
self._unknown_enum_classes.add(enum_class_str)
1034-
super().report_error(error)
1148+
self.report_error(InvalidExpressionError(repr_))
1149+
elif len(self._repr_to_value_and_import[repr_]) > 1:
1150+
self.report_error(
1151+
AmbiguousEnumError(repr_, *self._repr_to_value_and_import[repr_])
1152+
)
1153+
else:
1154+
# fix the invalid enum expressions
1155+
value, import_ = self._repr_to_value_and_import[repr_].pop()
1156+
for arg in args:
1157+
module = self._invalid_default_argument_to_module[arg]
1158+
if module.origin == import_.origin:
1159+
arg.default = Value(
1160+
repr=value.repr[len(str(module.origin)) + 1 :],
1161+
is_print_safe=True,
1162+
)
1163+
else:
1164+
arg.default = value
1165+
module.imports.add(import_)
10351166

1036-
def finalize(self):
10371167
if self._unknown_enum_classes:
1168+
# TODO: does this case still exist in practice? How would pybind11 display
1169+
# a repr for an enum field whose definition we did not see while parsing?
10381170
logger.warning(
10391171
"Enum-like str representations were found with no "
10401172
"matching mapping to the enum class location.\n"

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def handle_class_member(
8585
def handle_module(
8686
self, path: QualifiedName, module: types.ModuleType
8787
) -> Module | None:
88-
result = Module(name=path[-1])
88+
result = Module(name=path[-1], origin=QualifiedName.from_str(module.__name__))
8989
for name, member in inspect.getmembers(module):
9090
obj = self.handle_module_member(
9191
QualifiedName([*path, Identifier(name)]), module, member

pybind11_stubgen/structs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def parent(self) -> QualifiedName:
5050
return QualifiedName(self[:-1])
5151

5252

53-
@dataclass
53+
@dataclass(eq=False)
5454
class Value:
5555
repr: str
5656
is_print_safe: bool = False # `self.repr` is valid python and safe to print as is
@@ -110,7 +110,7 @@ class Attribute:
110110
annotation: Annotation | None = field_(default=None)
111111

112112

113-
@dataclass
113+
@dataclass(eq=False)
114114
class Argument:
115115
name: Identifier | None
116116
pos_only: bool = field_(default=False)
@@ -191,6 +191,7 @@ class Import:
191191
@dataclass
192192
class Module:
193193
name: Identifier
194+
origin: QualifiedName
194195
doc: Docstring | None = field_(default=None)
195196
classes: list[Class] = field_(default_factory=list)
196197
functions: list[Function] = field_(default_factory=list)

tests/check-demo-stubs-generation.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ run_stubgen() {
4343
demo \
4444
--output-dir=${STUBS_DIR} \
4545
${NUMPY_FORMAT} \
46-
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)|<demo\._bindings\.flawed_bindings\..*" \
47-
--enum-class-locations="ConsoleForegroundColor:demo._bindings.enum" \
46+
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)|<demo\._bindings\.flawed_bindings\..*|<ConsoleForegroundColor\\.Magenta: 35>" \
4847
--print-safe-value-reprs="Foo\(\d+\)" \
4948
--exit-code
5049
}

tests/demo-lib/include/demo/sublibA/ConsoleColors.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ enum class ConsoleForegroundColor {
99
None_ = -1
1010
};
1111

12+
enum class ConsoleForegroundColorDuplicate {
13+
Green = 32,
14+
Yellow = 33,
15+
Blue = 34,
16+
Magenta = 35,
17+
None_ = -1
18+
};
19+
1220
enum ConsoleBackgroundColor {
1321
Green = 42,
1422
Yellow = 43,

tests/demo.errors.stderr.txt

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,15 @@
1-
pybind11_stubgen - [ ERROR] In demo._bindings.aliases.foreign_enum_default : Invalid expression '<ConsoleForegroundColor.Blue: 34>'
21
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_c : Can't find/import 'm'
32
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_c : Can't find/import 'n'
43
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_r : Can't find/import 'm'
54
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_r : Can't find/import 'n'
65
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.four_col_matrix_r : Can't find/import 'm'
76
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.four_row_matrix_r : Can't find/import 'n'
8-
pybind11_stubgen - [ ERROR] In demo._bindings.enum.accept_defaulted_enum : Invalid expression '<ConsoleForegroundColor.None_: -1>'
97
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_enum : Invalid expression '(anonymous namespace)::Enum'
108
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_enum_defaulted : Invalid expression '<demo._bindings.flawed_bindings.Enum object at 0x1234abcd5678>'
119
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_type : Invalid expression '(anonymous namespace)::Unbound'
1210
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_type_defaulted : Invalid expression '<demo._bindings.flawed_bindings.Unbound object at 0x1234abcd5678>'
1311
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.get_unbound_type : Invalid expression '(anonymous namespace)::Unbound'
14-
pybind11_stubgen - [WARNING] Enum-like str representations were found with no matching mapping to the enum class location.
15-
Use `--enum-class-locations` to specify full path to the following enum(s):
16-
- ConsoleForegroundColor
12+
pybind11_stubgen - [ ERROR] In finalize : Enum member '<ConsoleForegroundColor.Magenta: 35>' could not be resolved; multiple matching definitions found in: 'demo._bindings.duplicate_enum', 'demo._bindings.enum'
1713
pybind11_stubgen - [WARNING] Raw C++ types/values were found in signatures extracted from docstrings.
1814
Please check the corresponding sections of pybind11 documentation to avoid common mistakes in binding code:
1915
- https://pybind11.readthedocs.io/en/latest/advanced/misc.html#avoiding-cpp-types-in-docstrings

tests/py-demo/bindings/src/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ PYBIND11_MODULE(_bindings, m) {
44
bind_classes_module(m.def_submodule("classes"));
55
bind_eigen_module(m.def_submodule("eigen"));
66
bind_enum_module(m.def_submodule("enum"));
7+
bind_duplicate_enum_module(m.def_submodule("duplicate_enum"));
78
bind_aliases_module(m.def_submodule("aliases"));
89
bind_flawed_bindings_module(m.def_submodule("flawed_bindings"));
910
bind_functions_module(m.def_submodule("functions"));

tests/py-demo/bindings/src/modules.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ void bind_aliases_module(py::module&& m);
1010
void bind_classes_module(py::module&& m);
1111
void bind_eigen_module(py::module&& m);
1212
void bind_enum_module(py::module&& m);
13+
void bind_duplicate_enum_module(py::module&& m);
1314
void bind_flawed_bindings_module(py::module&& m);
1415
void bind_functions_module(py::module&& m);
1516
void bind_issues_module(py::module&& m);

0 commit comments

Comments
 (0)