Skip to content

Commit 27d16df

Browse files
authored
Reconstruct defaulted arg values from docstring (#147)
Changes: - Recognize Python literal strings in default arg as "print-safe" - Add `--print-safe-value-reprs=REGEX` CLI option to override the print-safe flag of `Value` (for custom reprs provided via `pybind11::arg_v()`) - Add `--enum-class-locations=REGEX:LOC` CLI option to rewrite enum values as valid Python expressions with correct imports. This change introduces new errors since earlier Enum-like representations (e.g. `<MyEnum.Value: 1>`) were treated as non-printable `Value`s and rendered as `...`.
1 parent ca280d8 commit 27d16df

File tree

20 files changed

+265
-59
lines changed

20 files changed

+265
-59
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@ pybind11-stubgen [-h]
2828
[--ignore-invalid-identifiers REGEX]
2929
[--ignore-unresolved-names REGEX]
3030
[--ignore-all-errors]
31+
[--enum-class-locations [REGEX:LOC ...]]
3132
[--numpy-array-wrap-with-annotated|
3233
--numpy-array-remove-parameters]
3334
[--print-invalid-expressions-as-is]
35+
[--print-safe-value-reprs REGEX]
3436
[--exit-code]
3537
[--stub-extension EXT]
3638
MODULE_NAME

pybind11_stubgen/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@
4242
FixTypingExtTypeNames,
4343
FixTypingTypeNames,
4444
FixValueReprRandomAddress,
45+
OverridePrintSafeValues,
4546
RemoveSelfAnnotation,
4647
ReplaceReadWritePropertyWithField,
48+
RewritePybind11EnumValueRepr,
4749
)
4850
from pybind11_stubgen.parser.mixins.parse import (
4951
BaseParser,
@@ -62,6 +64,12 @@ def regex(pattern_str: str) -> re.Pattern:
6264
except re.error as e:
6365
raise ValueError(f"Invalid REGEX pattern: {e}")
6466

67+
def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
    """Parse a ``REGEX:LOC`` command-line value into a pattern and a path.

    The value is split on its *last* colon, so the regex part may itself
    contain colons. The part after the colon must be a dotted sequence of
    valid Python identifiers (e.g. ``foo.bar.Baz``).

    Args:
        regex_path: Raw argument text, e.g. ``"MyEnum:foo.bar"``.

    Returns:
        A ``(pattern, path)`` pair, where ``pattern`` is whatever ``regex()``
        returns for the text before the last colon and ``path`` is the text
        after it.

    Raises:
        ValueError: If there is no colon, if the path part is not a dotted
            identifier sequence, or if ``regex()`` rejects the pattern.
            ``ValueError`` is kept (rather than a more specific type) so the
            function behaves like the sibling ``regex`` converter.
    """
    pattern_str, separator, path = regex_path.rpartition(":")
    if not separator:
        # Without this check the failure was an opaque tuple-unpacking error.
        raise ValueError(
            f"Expected REGEX:LOC format (missing ':'): {regex_path}"
        )
    if any(not part.isidentifier() for part in path.split(".")):
        raise ValueError(f"Invalid PATH: {path}")
    return regex(pattern_str), path
72+
6573
parser = ArgumentParser(
6674
prog="pybind11-stubgen", description="Generates stubs for specified modules"
6775
)
@@ -109,6 +117,18 @@ def regex(pattern_str: str) -> re.Pattern:
109117
help="Ignore all errors during module parsing",
110118
)
111119

120+
parser.add_argument(
121+
"--enum-class-locations",
122+
dest="enum_class_locations",
123+
metavar="REGEX:LOC",
124+
default=[],
125+
nargs="*",
126+
type=regex_colon_path,
127+
help="Locations of enum classes in "
128+
"<enum-class-name-regex>:<path-to-class> format. "
129+
"Example: `MyEnum:foo.bar.Baz`",
130+
)
131+
112132
numpy_array_fix = parser.add_mutually_exclusive_group()
113133
numpy_array_fix.add_argument(
114134
"--numpy-array-wrap-with-annotated",
@@ -133,6 +153,14 @@ def regex(pattern_str: str) -> re.Pattern:
133153
help="Suppress invalid expression replacement with '...'",
134154
)
135155

156+
parser.add_argument(
157+
"--print-safe-value-reprs",
158+
metavar="REGEX",
159+
default=None,
160+
type=regex,
161+
help="Override the print-safe check for values matching REGEX",
162+
)
163+
136164
parser.add_argument(
137165
"--exit-code",
138166
action="store_true",
@@ -202,10 +230,12 @@ class Parser(
202230
FixTypingExtTypeNames,
203231
FixMissingFixedSizeImport,
204232
FixMissingEnumMembersAnnotation,
233+
OverridePrintSafeValues,
205234
*numpy_fixes, # type: ignore[misc]
206235
FixNumpyArrayFlags,
207236
FixCurrentModulePrefixInTypeNames,
208237
FixBuiltinTypes,
238+
RewritePybind11EnumValueRepr,
209239
FilterClassMembers,
210240
ReplaceReadWritePropertyWithField,
211241
FilterInvalidIdentifiers,
@@ -224,12 +254,16 @@ class Parser(
224254

225255
parser = Parser()
226256

257+
if args.enum_class_locations:
258+
parser.set_pybind11_enum_locations(dict(args.enum_class_locations))
227259
if args.ignore_invalid_identifiers is not None:
228260
parser.set_ignored_invalid_identifiers(args.ignore_invalid_identifiers)
229261
if args.ignore_invalid_expressions is not None:
230262
parser.set_ignored_invalid_expressions(args.ignore_invalid_expressions)
231263
if args.ignore_unresolved_names is not None:
232264
parser.set_ignored_unresolved_names(args.ignore_unresolved_names)
265+
if args.print_safe_value_reprs is not None:
266+
parser.set_print_safe_value_pattern(args.print_safe_value_reprs)
233267
return parser
234268

235269

pybind11_stubgen/parser/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def parse_annotation_str(
106106
...
107107

108108
@abc.abstractmethod
109-
def parse_value_str(self, value: str) -> Value:
109+
def parse_value_str(self, value: str) -> Value | InvalidExpression:
110110
...
111111

112112
@abc.abstractmethod

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import inspect
66
import re
77
import types
8+
from logging import getLogger
89
from typing import Any
910

1011
from pybind11_stubgen.parser.errors import NameResolutionError, ParserError
@@ -29,6 +30,8 @@
2930
)
3031
from pybind11_stubgen.typing_ext import DynamicSize, FixedSize
3132

33+
logger = getLogger("pybind11_stubgen")
34+
3235

3336
class RemoveSelfAnnotation(IParser):
3437
def handle_method(self, path: QualifiedName, method: Any) -> list[Method]:
@@ -363,7 +366,7 @@ class FixTypingExtTypeNames(IParser):
363366
__typing_names: set[Identifier] = set(
364367
map(
365368
Identifier,
366-
["buffer"],
369+
["buffer", "Buffer"],
367370
)
368371
)
369372

@@ -751,3 +754,63 @@ def handle_class_member(
751754
method.modifier = None
752755
method.function.doc = None
753756
return result
757+
758+
759+
class OverridePrintSafeValues(IParser):
    """Parser mixin that force-marks selected values as print-safe.

    After the next parser in the MRO has parsed a value string, a ``Value``
    that was *not* considered print-safe is flagged as print-safe when its
    ``repr`` matches the configured pattern. Without a configured pattern
    the mixin is a pass-through.
    """

    # Pattern tested (anchored at the start, via ``match``) against
    # ``Value.repr``; ``None`` disables the override.
    _print_safe_values: re.Pattern | None

    def __init__(self):
        super().__init__()
        self._print_safe_values = None

    def set_print_safe_value_pattern(self, pattern: re.Pattern):
        """Set the pattern selecting values to be treated as print-safe."""
        self._print_safe_values = pattern

    def parse_value_str(self, value: str) -> Value | InvalidExpression:
        """Parse ``value`` via ``super()`` and apply the print-safe override.

        Invalid expressions and values already flagged print-safe are
        returned untouched.
        """
        parsed = super().parse_value_str(value)
        override = self._print_safe_values
        if override is None or not isinstance(parsed, Value) or parsed.is_print_safe:
            return parsed
        if override.match(parsed.repr) is not None:
            parsed.is_print_safe = True
        return parsed
779+
780+
781+
class RewritePybind11EnumValueRepr(IParser):
    """Parser mixin rewriting enum-like reprs into Python expressions.

    A value string such as ``<MyEnum.Value: 1>`` is not a valid Python
    expression. When the enum class name matches one of the configured
    location patterns, the value is rewritten to ``<resolved class>.<entry>``
    and marked print-safe. Enum class names for which no rewrite happened are
    collected and reported once in ``finalize()``.
    """

    # Matches e.g. ``<MyEnum.Value: 1>`` or ``<Outer.MyEnum.Value: -1>``.
    # The sign is optional so enums with negative underlying values are
    # recognised too.
    _pybind11_enum_pattern = re.compile(
        r"<(?P<enum>\w+(\.\w+)+): (?P<value>-?\d+)>"
    )

    def __init__(self):
        super().__init__()
        self._pybind11_enum_locations: dict[re.Pattern, str] = {}
        # Kept per instance: a class-level set would be shared by (and leak
        # between) every parser created in the same process.
        self._unknown_enum_classes: set[str] = set()

    def set_pybind11_enum_locations(self, locations: dict[re.Pattern, str]):
        """Set the mapping of enum-class-name pattern -> location prefix."""
        self._pybind11_enum_locations = locations

    def parse_value_str(self, value: str) -> Value | InvalidExpression:
        """Rewrite an enum-like repr, else defer to the next parser.

        Returns a print-safe ``Value`` when the enum class matches a
        configured pattern and ``parse_annotation_str`` yields a
        ``ResolvedType`` for ``<prefix>.<class>``. Otherwise the stripped
        value is passed on to ``super().parse_value_str``.
        """
        value = value.strip()
        match = self._pybind11_enum_pattern.match(value)
        if match is not None:
            enum_qual_name = match.group("enum")
            enum_class_str, entry = enum_qual_name.rsplit(".", maxsplit=1)
            for pattern, prefix in self._pybind11_enum_locations.items():
                if pattern.match(enum_class_str) is None:
                    continue
                enum_class = self.parse_annotation_str(f"{prefix}.{enum_class_str}")
                if isinstance(enum_class, ResolvedType):
                    return Value(repr=f"{enum_class.name}.{entry}", is_print_safe=True)
            # Reached when no pattern matched, or a matching location did not
            # resolve to a type: remember the class for the final warning.
            self._unknown_enum_classes.add(enum_class_str)
        return super().parse_value_str(value)

    def finalize(self):
        """Warn about enum classes without a usable location, then chain up."""
        if self._unknown_enum_classes:
            logger.warning(
                "Enum-like str representations were found with no "
                "matching mapping to the enum class location.\n"
                "Use `--enum-class-locations` to specify "
                "full path to the following enum(s):\n"
                # Sorted so the report is stable between runs.
                + "\n".join(f" - {c}" for c in sorted(self._unknown_enum_classes))
            )
        super().finalize()

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ def handle_type(self, type_: type) -> QualifiedName:
380380
)
381381
)
382382

383-
def parse_value_str(self, value: str) -> Value:
384-
return Value(value)
383+
def parse_value_str(self, value: str) -> Value | InvalidExpression:
384+
return self._parse_expression_str(value)
385385

386386
def report_error(self, error: ParserError):
387387
if isinstance(error, NameResolutionError):
@@ -428,6 +428,21 @@ def _get_full_name(self, path: QualifiedName, origin: Any) -> QualifiedName | No
428428
return None
429429
return origin_name
430430

431+
def _parse_expression_str(self, expr_str: str) -> Value | InvalidExpression:
    """Classify a string as a ``Value`` or an ``InvalidExpression``.

    The stripped text must be accepted by ``ast.parse``; otherwise an
    ``InvalidExpressionError`` is reported through ``self.report_error`` and
    an ``InvalidExpression`` is returned. Text that parses is wrapped in a
    ``Value`` whose ``is_print_safe`` flag is true only when
    ``ast.literal_eval`` also accepts it.
    """
    text = expr_str.strip()

    # Step 1: the text has to be syntactically valid Python at all.
    try:
        ast.parse(text)
    except SyntaxError:
        self.report_error(InvalidExpressionError(text))
        return InvalidExpression(text)

    # Step 2: only plain literals are considered safe to print verbatim.
    try:
        ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        is_literal = False
    else:
        is_literal = True

    return Value(text, is_print_safe=is_literal)
445+
431446

432447
class ExtractSignaturesFromPybind11Docstrings(IParser):
433448
_arg_star_name_regex = re.compile(
@@ -571,7 +586,7 @@ def parse_type_str(
571586
annotation_str = annotation_str.strip()
572587
match = qname_regex.match(annotation_str)
573588
if match is None:
574-
return self._parse_expression_str(annotation_str)
589+
return self.parse_value_str(annotation_str)
575590
qual_name = QualifiedName(
576591
Identifier(part)
577592
for part in match.group("qual_name").replace(" ", "").split(".")
@@ -582,25 +597,17 @@ def parse_type_str(
582597
parameters = None
583598
else:
584599
if parameters_str[0] != "[" or parameters_str[-1] != "]":
585-
return self._parse_expression_str(annotation_str)
600+
return self.parse_value_str(annotation_str)
586601

587602
split_parameters = self._split_parameters_str(parameters_str[1:-1])
588603
if split_parameters is None:
589-
return self._parse_expression_str(annotation_str)
604+
return self.parse_value_str(annotation_str)
590605

591606
parameters = [
592607
self.parse_annotation_str(param_str) for param_str in split_parameters
593608
]
594609
return ResolvedType(name=qual_name, parameters=parameters)
595610

596-
def _parse_expression_str(self, expr_str: str) -> Value | InvalidExpression:
597-
try:
598-
ast.parse(expr_str)
599-
return self.parse_value_str(expr_str)
600-
except SyntaxError:
601-
self.report_error(InvalidExpressionError(expr_str))
602-
return InvalidExpression(expr_str)
603-
604611
def parse_function_docstring(
605612
self, func_name: Identifier, doc_lines: list[str]
606613
) -> list[Function]:

pybind11_stubgen/printer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,13 @@ def print_argument(self, arg: Argument) -> str:
6161
parts.append(f"{arg.name}")
6262
if arg.annotation is not None:
6363
parts.append(f": {self.print_annotation(arg.annotation)}")
64-
if arg.default is not None:
64+
if isinstance(arg.default, Value):
6565
if arg.default.is_print_safe:
6666
parts.append(f" = {self.print_value(arg.default)}")
6767
else:
6868
parts.append(" = ...")
69+
elif isinstance(arg.default, InvalidExpression):
70+
parts.append(f" = {self.print_invalid_exp(arg.default)}")
6971

7072
return "".join(parts)
7173

pybind11_stubgen/structs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class Argument:
9797
kw_only: bool = field_(default=False)
9898
variadic: bool = field_(default=False) # *args
9999
kw_variadic: bool = field_(default=False) # **kwargs
100-
default: Value | None = field_(default=None)
100+
default: Value | InvalidExpression | None = field_(default=None)
101101
annotation: Annotation | None = field_(default=None)
102102

103103
def __str__(self):

tests/check-demo-stubs-generation.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ run_stubgen() {
3838
demo \
3939
--output-dir=${STUBS_DIR} \
4040
--numpy-array-wrap-with-annotated \
41-
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)" \
41+
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)|<demo\._bindings\.flawed_bindings\..*" \
4242
--ignore-unresolved-names="typing\.Annotated" \
43+
--enum-class-locations="ConsoleForegroundColor:demo._bindings.enum" \
44+
--print-safe-value-reprs="Foo\(\d+\)" \
4345
--exit-code
4446
}
4547

tests/py-demo/bindings/src/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
PYBIND11_MODULE(_bindings, m) {
44
bind_classes_module(m.def_submodule("classes"));
5-
bind_aliases_module(m.def_submodule("aliases"));
65
bind_eigen_module(m.def_submodule("eigen"));
76
bind_enum_module(m.def_submodule("enum"));
7+
bind_aliases_module(m.def_submodule("aliases"));
88
bind_flawed_bindings_module(m.def_submodule("flawed_bindings"));
99
bind_functions_module(m.def_submodule("functions"));
1010
bind_issues_module(m.def_submodule("issues"));

tests/py-demo/bindings/src/modules/aliases.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
#include "modules.h"
22

33
#include <demo/Foo.h>
4+
#include <demo/sublibA/ConsoleColors.h>
45

56
namespace {
6-
class Dummy {};
7+
class Dummy {
8+
};
79

8-
struct Color {};
10+
struct Color {
11+
};
912

10-
struct Bar1 {};
11-
struct Bar2 {};
12-
struct Bar3 {};
13+
struct Bar1 {
14+
};
15+
struct Bar2 {
16+
};
17+
struct Bar3 {
18+
};
1319
} // namespace
1420

1521
void bind_aliases_module(py::module_ &&m) {
@@ -18,7 +24,7 @@ void bind_aliases_module(py::module_ &&m) {
1824
auto &&pyDummy = py::class_<Dummy>(m, "Dummy");
1925

2026
pyDummy.def_property_readonly_static(
21-
"linalg", [](py::object &) { return py::module::import("numpy.linalg"); });
27+
"linalg", [](py::object &) { return py::module::import("numpy.linalg"); });
2228

2329
m.add_object("random", py::module::import("numpy.random"));
2430
}
@@ -63,4 +69,10 @@ void bind_aliases_module(py::module_ &&m) {
6369
m.attr("foreign_type_alias") = m.attr("foreign_method_arg").attr("Bar2");
6470
m.attr("foreign_class_alias") = m.attr("foreign_return").attr("get_foo");
6571
}
72+
73+
m.def(
74+
"foreign_enum_default",
75+
[](const py::object & /* color */) {},
76+
py::arg("color") = demo::sublibA::ConsoleForegroundColor::Blue
77+
);
6678
}

0 commit comments

Comments
 (0)