From 2585a5dd3a5aeb2bd11e3d6be492d0916f99f1c1 Mon Sep 17 00:00:00 2001 From: gentlegiantJGC Date: Mon, 15 Sep 2025 10:38:20 +0100 Subject: [PATCH] Remove current module from nested types (#18) * Remove current module from nested types Types like union types have nested types in them. The current implementation only removes the current module name from the root type name. This modifies that behaviour so that it expands nested types recursively. * Add tests * Reorder variant types pybind11 requires that arguments are default constructable but Foo is not * Reorder types in union * Add nested types to all * Add missing stubs --- pybind11_stubgen/parser/mixins/fix.py | 8 +++++++- tests/py-demo/bindings/src/modules/functions.cpp | 4 ++++ .../demo/_bindings/functions.pyi | 2 ++ .../demo/_bindings/functions.pyi | 2 ++ .../numpy-array-use-type-var/demo/_bindings/functions.pyi | 2 ++ .../demo/_bindings/functions.pyi | 2 ++ .../demo/_bindings/functions.pyi | 2 ++ .../demo/_bindings/functions.pyi | 2 ++ 8 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pybind11_stubgen/parser/mixins/fix.py b/pybind11_stubgen/parser/mixins/fix.py index 87898906..56d49688 100644 --- a/pybind11_stubgen/parser/mixins/fix.py +++ b/pybind11_stubgen/parser/mixins/fix.py @@ -469,9 +469,15 @@ def parse_annotation_str( ) -> ResolvedType | InvalidExpression | Value: result = super().parse_annotation_str(annotation_str) if isinstance(result, ResolvedType): - result.name = self._strip_current_module(result.name) + self._strip_current_module_recursive(result) return result + def _strip_current_module_recursive(self, result: ResolvedType): + result.name = self._strip_current_module(result.name) + for parameter in result.parameters or (): + if isinstance(parameter, ResolvedType): + self._strip_current_module_recursive(parameter) + def _strip_current_module(self, name: QualifiedName) -> QualifiedName: if name[: len(self.__current_module)] == self.__current_module: return QualifiedName(name[len(self.__current_module) :]) diff --git a/tests/py-demo/bindings/src/modules/functions.cpp b/tests/py-demo/bindings/src/modules/functions.cpp index 969a4fba..15488da6 100644 --- a/tests/py-demo/bindings/src/modules/functions.cpp +++ b/tests/py-demo/bindings/src/modules/functions.cpp @@ -9,6 +9,9 @@ #include #include +#include +#include + #include namespace { @@ -97,4 +100,5 @@ void bind_functions_module(py::module &&m) { pyFoo.def(py::init()); m.def("default_custom_arg", [](Foo &foo) {}, py::arg_v("foo", Foo(5), "Foo(5)")); m.def("pass_callback", [](std::function &callback) { return Foo(13); }); + m.def("nested_types", [](std::variant, Foo> arg){ return arg; }); } diff --git a/tests/stubs/python-3.12/pybind11-v2.11/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi b/tests/stubs/python-3.12/pybind11-v2.11/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi index f6e0d53b..316386c8 100644 --- a/tests/stubs/python-3.12/pybind11-v2.11/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi +++ b/tests/stubs/python-3.12/pybind11-v2.11/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi @@ -18,6 +18,7 @@ __all__: list[str] = [ "func_w_named_pos_args", "generic", "mul", + "nested_types", "pass_callback", "pos_kw_only_mix", "pos_kw_only_variadic_mix", @@ -51,6 +52,7 @@ def mul(p: float, q: float) -> float: Multiply p and q (double) """ +def nested_types(arg0: list[Foo] | Foo) -> list[Foo] | Foo: ... def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ... def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ... def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ... diff --git a/tests/stubs/python-3.12/pybind11-v2.12/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi b/tests/stubs/python-3.12/pybind11-v2.12/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi index 79f666a0..72ab6498 100644 --- a/tests/stubs/python-3.12/pybind11-v2.12/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi +++ b/tests/stubs/python-3.12/pybind11-v2.12/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi @@ -19,6 +19,7 @@ __all__: list[str] = [ "func_w_named_pos_args", "generic", "mul", + "nested_types", "pass_callback", "pos_kw_only_mix", "pos_kw_only_variadic_mix", @@ -53,6 +54,7 @@ def mul(p: float, q: float) -> float: Multiply p and q (double) """ +def nested_types(arg0: list[Foo] | Foo) -> list[Foo] | Foo: ... def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ... def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ... def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ... diff --git a/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-use-type-var/demo/_bindings/functions.pyi b/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-use-type-var/demo/_bindings/functions.pyi index 79f666a0..72ab6498 100644 --- a/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-use-type-var/demo/_bindings/functions.pyi +++ b/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-use-type-var/demo/_bindings/functions.pyi @@ -19,6 +19,7 @@ __all__: list[str] = [ "func_w_named_pos_args", "generic", "mul", + "nested_types", "pass_callback", "pos_kw_only_mix", "pos_kw_only_variadic_mix", @@ -53,6 +54,7 @@ def mul(p: float, q: float) -> float: Multiply p and q (double) """ +def nested_types(arg0: list[Foo] | Foo) -> list[Foo] | Foo: ... def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ... def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ... def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ... diff --git a/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi b/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi index 79f666a0..72ab6498 100644 --- a/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi +++ b/tests/stubs/python-3.12/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi @@ -19,6 +19,7 @@ __all__: list[str] = [ "func_w_named_pos_args", "generic", "mul", + "nested_types", "pass_callback", "pos_kw_only_mix", "pos_kw_only_variadic_mix", @@ -53,6 +54,7 @@ def mul(p: float, q: float) -> float: Multiply p and q (double) """ +def nested_types(arg0: list[Foo] | Foo) -> list[Foo] | Foo: ... def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ... def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ... def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ... diff --git a/tests/stubs/python-3.12/pybind11-v2.9/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi b/tests/stubs/python-3.12/pybind11-v2.9/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi index d5d208be..2f467b68 100644 --- a/tests/stubs/python-3.12/pybind11-v2.9/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi +++ b/tests/stubs/python-3.12/pybind11-v2.9/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi @@ -17,6 +17,7 @@ __all__: list[str] = [ "func_w_named_pos_args", "generic", "mul", + "nested_types", "pass_callback", "pos_kw_only_mix", "pos_kw_only_variadic_mix", @@ -49,6 +50,7 @@ def mul(p: float, q: float) -> float: Multiply p and q (double) """ +def nested_types(arg0: list[Foo] | Foo) -> list[Foo] | Foo: ... def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ... def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ... def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ... diff --git a/tests/stubs/python-3.8/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi b/tests/stubs/python-3.8/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi index 79f666a0..72ab6498 100644 --- a/tests/stubs/python-3.8/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi +++ b/tests/stubs/python-3.8/pybind11-v2.13/numpy-array-wrap-with-annotated/demo/_bindings/functions.pyi @@ -19,6 +19,7 @@ __all__: list[str] = [ "func_w_named_pos_args", "generic", "mul", + "nested_types", "pass_callback", "pos_kw_only_mix", "pos_kw_only_variadic_mix", @@ -53,6 +54,7 @@ def mul(p: float, q: float) -> float: Multiply p and q (double) """ +def nested_types(arg0: list[Foo] | Foo) -> list[Foo] | Foo: ... def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ... def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ... def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...