Skip to content

Commit d9ab41d

Browse files
authored
Append numpy prefix to dtype (#179)
1 parent f5b044d commit d9ab41d

File tree

8 files changed

+33
-0
lines changed

8 files changed

+33
-0
lines changed

pybind11_stubgen/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
FixNumpyArrayDimAnnotation,
3636
FixNumpyArrayFlags,
3737
FixNumpyArrayRemoveParameters,
38+
FixNumpyDtype,
3839
FixPEP585CollectionNames,
3940
FixPybind11EnumStrDoc,
4041
FixRedundantBuiltinsAnnotation,
@@ -231,6 +232,7 @@ class Parser(
231232
FixMissingEnumMembersAnnotation,
232233
OverridePrintSafeValues,
233234
*numpy_fixes, # type: ignore[misc]
235+
FixNumpyDtype,
234236
FixNumpyArrayFlags,
235237
FixCurrentModulePrefixInTypeNames,
236238
FixBuiltinTypes,

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,26 @@ def parse_annotation_str(
586586
return result
587587

588588

589+
class FixNumpyDtype(IParser):
590+
__numpy_dtype = QualifiedName.from_str("numpy.dtype")
591+
592+
def parse_annotation_str(
593+
self, annotation_str: str
594+
) -> ResolvedType | InvalidExpression | Value:
595+
result = super().parse_annotation_str(annotation_str)
596+
if (
597+
not isinstance(result, ResolvedType)
598+
or len(result.name) != 1
599+
or result.parameters is not None
600+
):
601+
return result
602+
603+
word = result.name[0]
604+
if word != Identifier("dtype"):
605+
return result
606+
return ResolvedType(name=self.__numpy_dtype)
607+
608+
589609
class FixNumpyArrayFlags(IParser):
590610
__ndarray_name = QualifiedName.from_str("numpy.ndarray")
591611
__flags: set[QualifiedName] = {

tests/py-demo/bindings/src/modules/numpy.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ void bind_numpy_module(py::module&&m) {
99
m.def("accept_ndarray_int", [](py::array_t<int> &) {});
1010
m.def("accept_ndarray_float64", [](py::array_t<double> &) {});
1111
}
12+
m.def("return_dtype", []() { return py::dtype("<i4"); });
1213
}

tests/stubs/python-3.12/pybind11-master/demo/_bindings/numpy.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ __all__ = [
99
"accept_ndarray_int",
1010
"get_ndarray_float64",
1111
"get_ndarray_int",
12+
"return_dtype",
1213
]
1314

1415
def accept_ndarray_float64(
@@ -17,3 +18,4 @@ def accept_ndarray_float64(
1718
def accept_ndarray_int(arg0: typing.Annotated[numpy.ndarray, numpy.int32]) -> None: ...
1819
def get_ndarray_float64() -> typing.Annotated[numpy.ndarray, numpy.float64]: ...
1920
def get_ndarray_int() -> typing.Annotated[numpy.ndarray, numpy.int32]: ...
21+
def return_dtype() -> numpy.dtype: ...

tests/stubs/python-3.12/pybind11-v2.11/demo/_bindings/numpy.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ __all__ = [
99
"accept_ndarray_int",
1010
"get_ndarray_float64",
1111
"get_ndarray_int",
12+
"return_dtype",
1213
]
1314

1415
def accept_ndarray_float64(
@@ -17,3 +18,4 @@ def accept_ndarray_float64(
1718
def accept_ndarray_int(arg0: typing.Annotated[numpy.ndarray, numpy.int32]) -> None: ...
1819
def get_ndarray_float64() -> typing.Annotated[numpy.ndarray, numpy.float64]: ...
1920
def get_ndarray_int() -> typing.Annotated[numpy.ndarray, numpy.int32]: ...
21+
def return_dtype() -> numpy.dtype: ...

tests/stubs/python-3.12/pybind11-v2.9/demo/_bindings/numpy.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ __all__ = [
99
"accept_ndarray_int",
1010
"get_ndarray_float64",
1111
"get_ndarray_int",
12+
"return_dtype",
1213
]
1314

1415
def accept_ndarray_float64(
@@ -17,3 +18,4 @@ def accept_ndarray_float64(
1718
def accept_ndarray_int(arg0: typing.Annotated[numpy.ndarray, numpy.int32]) -> None: ...
1819
def get_ndarray_float64() -> typing.Annotated[numpy.ndarray, numpy.float64]: ...
1920
def get_ndarray_int() -> typing.Annotated[numpy.ndarray, numpy.int32]: ...
21+
def return_dtype() -> numpy.dtype: ...

tests/stubs/python-3.7/pybind11-master/demo/_bindings/numpy.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ __all__ = [
88
"accept_ndarray_int",
99
"get_ndarray_float64",
1010
"get_ndarray_int",
11+
"return_dtype",
1112
]
1213

1314
def accept_ndarray_float64(
@@ -20,3 +21,4 @@ def get_ndarray_float64() -> typing_extensions.Annotated[
2021
numpy.ndarray, numpy.float64
2122
]: ...
2223
def get_ndarray_int() -> typing_extensions.Annotated[numpy.ndarray, numpy.int32]: ...
24+
def return_dtype() -> numpy.dtype: ...

tests/stubs/python-3.8/pybind11-master/demo/_bindings/numpy.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ __all__ = [
88
"accept_ndarray_int",
99
"get_ndarray_float64",
1010
"get_ndarray_int",
11+
"return_dtype",
1112
]
1213

1314
def accept_ndarray_float64(
@@ -20,3 +21,4 @@ def get_ndarray_float64() -> typing_extensions.Annotated[
2021
numpy.ndarray, numpy.float64
2122
]: ...
2223
def get_ndarray_int() -> typing_extensions.Annotated[numpy.ndarray, numpy.int32]: ...
24+
def return_dtype() -> numpy.dtype: ...

0 commit comments

Comments
 (0)