Skip to content

Commit 1b04a81

Browse files
authored
Support [list]-like annotations (#184)
Fix #183
1 parent afc508f commit 1b04a81

File tree

9 files changed

+75
-39
lines changed

9 files changed

+75
-39
lines changed

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,18 @@ def parse_annotation_str(
131131
return result
132132

133133
def _add_import(self, name: QualifiedName) -> None:
134-
if len(name) > 0:
135-
if hasattr(builtins, name[0]):
136-
return
137-
if self.__current_class is not None and hasattr(
138-
self.__current_class, name[0]
139-
):
140-
return
141-
if self.__current_module is not None and hasattr(
142-
self.__current_module, name[0]
143-
):
144-
return
134+
if len(name) == 0:
135+
return
136+
if len(name) == 1 and len(name[0]) == 0:
137+
return
138+
if hasattr(builtins, name[0]):
139+
return
140+
if self.__current_class is not None and hasattr(self.__current_class, name[0]):
141+
return
142+
if self.__current_module is not None and hasattr(
143+
self.__current_module, name[0]
144+
):
145+
return
145146
module_name = self._get_parent_module(name)
146147
if module_name is None:
147148
self.report_error(NameResolutionError(name))
@@ -495,6 +496,8 @@ class FixNumpyArrayDimAnnotation(IParser):
495496
)
496497
)
497498

499+
__DIM_VARS = ["n", "m"]
500+
498501
def parse_annotation_str(
499502
self, annotation_str: str
500503
) -> ResolvedType | InvalidExpression | Value:
@@ -518,12 +521,6 @@ def parse_annotation_str(
518521
if (
519522
not isinstance(scalar_with_dims, ResolvedType)
520523
or scalar_with_dims.name not in self.numpy_primitive_types
521-
or (
522-
scalar_with_dims.parameters is not None
523-
and any(
524-
not isinstance(dim, Value) for dim in scalar_with_dims.parameters
525-
)
526-
)
527524
):
528525
return result
529526

@@ -540,38 +537,57 @@ def parse_annotation_str(
540537
scalar_with_dims.parameters is not None
541538
and len(scalar_with_dims.parameters) >= 0
542539
):
543-
result.parameters += [
544-
self.handle_value(
545-
self._cook_dimension_parameters(scalar_with_dims.parameters)
546-
)
547-
]
540+
dims = self.__to_dims(scalar_with_dims.parameters)
541+
if dims is not None and len(dims) > 0:
542+
result.parameters += [
543+
self.handle_value(self.__wrap_with_size_helper(dims))
544+
]
548545

549546
result.parameters += flags
550547

551548
return result
552549

553-
def _cook_dimension_parameters(
554-
self, dimensions: list[Value]
555-
) -> FixedSize | DynamicSize:
556-
all_ints = True
557-
new_params = []
558-
for dim_param in dimensions:
559-
try:
560-
dim = int(dim_param.repr)
561-
except ValueError:
562-
dim = dim_param.repr
563-
all_ints = False
564-
new_params.append(dim)
565-
566-
if all_ints:
550+
def __wrap_with_size_helper(self, dims: list[int | str]) -> FixedSize | DynamicSize:
551+
if all(isinstance(d, int) for d in dims):
567552
return_t = FixedSize
568553
else:
569554
return_t = DynamicSize
570555

571556
# TRICK: Use `self.handle_type` to make `FixedSize`/`DynamicSize`
572557
# properly added to the list of imports
573-
self.handle_type(FixedSize)
574-
return return_t(*new_params)
558+
self.handle_type(return_t)
559+
return return_t(*dims)
560+
561+
def __to_dims(
562+
self, dimensions: list[ResolvedType | Value | InvalidExpression]
563+
) -> list[int | str] | None:
564+
result = []
565+
for dim_param in dimensions:
566+
if isinstance(dim_param, Value):
567+
try:
568+
dim = int(dim_param.repr)
569+
except ValueError:
570+
return None
571+
elif isinstance(dim_param, ResolvedType):
572+
dim = str(dim_param)
573+
if dim not in self.__DIM_VARS:
574+
return None
575+
else:
576+
return None
577+
result.append(dim)
578+
return result
579+
580+
def report_error(self, error: ParserError) -> None:
581+
582+
if (
583+
isinstance(error, NameResolutionError)
584+
and len(error.name) == 1
585+
and len(error.name[0]) == 1
586+
and error.name[0] in self.__DIM_VARS
587+
):
588+
# Ignores all unknown 'm' and 'n' regardless of the context
589+
return
590+
super().report_error(error)
575591

576592

577593
class FixNumpyArrayRemoveParameters(IParser):

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def parse_type_str(
581581
self, annotation_str: str
582582
) -> ResolvedType | InvalidExpression | Value:
583583
qname_regex = re.compile(
584-
r"^\s*(?P<qual_name>[_A-Za-z]\w+(\s*\.\s*[_A-Za-z]\w+)*)"
584+
r"^\s*(?P<qual_name>([_A-Za-z]\w*)?(\s*\.\s*[_A-Za-z]\w*)*)"
585585
)
586586
annotation_str = annotation_str.strip()
587587
match = qname_regex.match(annotation_str)

tests/demo.errors.stderr.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
pybind11_stubgen - [ ERROR] In demo._bindings.aliases.foreign_enum_default : Invalid expression '<ConsoleForegroundColor.Blue: 34>'
2+
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_c : Can't find/import 'm'
3+
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_c : Can't find/import 'n'
4+
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_r : Can't find/import 'm'
5+
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.dense_matrix_r : Can't find/import 'n'
6+
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.four_col_matrix_r : Can't find/import 'm'
7+
pybind11_stubgen - [ ERROR] In demo._bindings.eigen.four_row_matrix_r : Can't find/import 'n'
28
pybind11_stubgen - [ ERROR] In demo._bindings.enum.accept_defaulted_enum : Invalid expression '<ConsoleForegroundColor.None_: -1>'
39
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_enum : Invalid expression '(anonymous namespace)::Enum'
410
pybind11_stubgen - [ ERROR] In demo._bindings.flawed_bindings.accept_unbound_enum_defaulted : Invalid expression '<demo._bindings.flawed_bindings.Enum object at 0x1234abcd5678>'

tests/py-demo/bindings/src/modules/functions.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
#include "modules.h"
22

33
#if PYBIND11_VERSION_AT_LEAST(2, 12)
4+
45
#include <pybind11/typing.h>
6+
57
#endif
68

79
#include <pybind11/stl.h>
10+
#include <pybind11/functional.h>
811

912
#include <demo/sublibA/add.h>
1013

@@ -93,4 +96,5 @@ void bind_functions_module(py::module &&m) {
9396
py::class_<Foo> pyFoo(m, "Foo");
9497
pyFoo.def(py::init<int>());
9598
m.def("default_custom_arg", [](Foo &foo) {}, py::arg_v("foo", Foo(5), "Foo(5)"));
99+
m.def("pass_callback", [](std::function<Foo(Foo &)> &callback) { return Foo(13); });
96100
}

tests/stubs/python-3.12/pybind11-master/demo/_bindings/functions.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ __all__ = [
1919
"func_w_named_pos_args",
2020
"generic",
2121
"mul",
22+
"pass_callback",
2223
"pos_kw_only_mix",
2324
"pos_kw_only_variadic_mix",
2425
]
@@ -52,5 +53,6 @@ def mul(p: float, q: float) -> float:
5253
Multiply p and q (double)
5354
"""
5455

56+
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
5557
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
5658
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

tests/stubs/python-3.12/pybind11-v2.11/demo/_bindings/functions.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ __all__ = [
1818
"func_w_named_pos_args",
1919
"generic",
2020
"mul",
21+
"pass_callback",
2122
"pos_kw_only_mix",
2223
"pos_kw_only_variadic_mix",
2324
]
@@ -50,5 +51,6 @@ def mul(p: float, q: float) -> float:
5051
Multiply p and q (double)
5152
"""
5253

54+
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
5355
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
5456
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

tests/stubs/python-3.12/pybind11-v2.9/demo/_bindings/functions.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ __all__ = [
1717
"func_w_named_pos_args",
1818
"generic",
1919
"mul",
20+
"pass_callback",
2021
"pos_kw_only_mix",
2122
"pos_kw_only_variadic_mix",
2223
]
@@ -48,5 +49,6 @@ def mul(p: float, q: float) -> float:
4849
Multiply p and q (double)
4950
"""
5051

52+
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
5153
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
5254
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

tests/stubs/python-3.7/pybind11-master/demo/_bindings/functions.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ __all__ = [
1919
"func_w_named_pos_args",
2020
"generic",
2121
"mul",
22+
"pass_callback",
2223
"pos_kw_only_mix",
2324
"pos_kw_only_variadic_mix",
2425
]
@@ -52,5 +53,6 @@ def mul(p: float, q: float) -> float:
5253
Multiply p and q (double)
5354
"""
5455

56+
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
5557
def pos_kw_only_mix(i: int, j: int, *, k: int) -> tuple: ...
5658
def pos_kw_only_variadic_mix(i: int, j: int, *args, k: int, **kwargs) -> tuple: ...

tests/stubs/python-3.8/pybind11-master/demo/_bindings/functions.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ __all__ = [
1919
"func_w_named_pos_args",
2020
"generic",
2121
"mul",
22+
"pass_callback",
2223
"pos_kw_only_mix",
2324
"pos_kw_only_variadic_mix",
2425
]
@@ -52,5 +53,6 @@ def mul(p: float, q: float) -> float:
5253
Multiply p and q (double)
5354
"""
5455

56+
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
5557
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
5658
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

0 commit comments

Comments
 (0)