Skip to content

Commit c211941

Browse files
authored
Missing rename for numpy.ndarray.flags (#128)
- 🐛 Fix missing remap of `numpy.ndarray.flags` - ✨ Process `scipy.sparse.*` types the same as `numpy.ndarray` with `--numpy-array-wrap-with-annotated` - ✨ Support dynamic array size with `--numpy-array-wrap-with-annotated` - ❗️ Renamed CLI argument `--numpy-array-wrap-with-annotated-fixed-size` to `--numpy-array-wrap-with-annotated`
1 parent 6499225 commit c211941

File tree

19 files changed

+425
-84
lines changed

19 files changed

+425
-84
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
Changelog
22
=========
33

4+
Version 1.x (TBA)
5+
--------------------------
6+
Changes:
7+
8+
- 🐛 Fix missing remap of `numpy.ndarray.flags` (#128)
9+
- ✨ Process `scipy.sparse.*` types the same as `numpy.ndarray` with `--numpy-array-wrap-with-annotated` (#128)
10+
- ✨ Support dynamic array size with `--numpy-array-wrap-with-annotated` (#128)
11+
- ❗️ Renamed CLI argument `--numpy-array-wrap-with-annotated-fixed-size` to `--numpy-array-wrap-with-annotated` (#128)
12+
413

514
Version 1.2 (Aug 31, 2023)
615
--------------------------

pybind11_stubgen/__init__.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from pybind11_stubgen.parser.interface import IParser
1010
from pybind11_stubgen.parser.mixins.error_handlers import (
1111
IgnoreAllErrors,
12-
IgnoreFixedErrors,
1312
IgnoreInvalidExpressionErrors,
1413
IgnoreInvalidIdentifierErrors,
1514
IgnoreUnresolvedNameErrors,
@@ -32,6 +31,7 @@
3231
FixMissingImports,
3332
FixMissingNoneHashFieldAnnotation,
3433
FixNumpyArrayDimAnnotation,
34+
FixNumpyArrayFlags,
3535
FixNumpyArrayRemoveParameters,
3636
FixPEP585CollectionNames,
3737
FixPybind11EnumStrDoc,
@@ -109,11 +109,12 @@ def regex(pattern_str: str) -> re.Pattern:
109109

110110
numpy_array_fix = parser.add_mutually_exclusive_group()
111111
numpy_array_fix.add_argument(
112-
"--numpy-array-wrap-with-annotated-fixed-size",
112+
"--numpy-array-wrap-with-annotated",
113113
default=False,
114114
action="store_true",
115-
help="Replace 'numpy.ndarray[<TYPE>, [*DIMS]]' with "
116-
"'Annotated[numpy.ndarray, TYPE, FixedSize(*DIMS)]'",
115+
help="Replace numpy/scipy arrays of "
116+
"'ARRAY_T[<TYPE>, [*DIMS], *FLAGS]' format with "
117+
"'Annotated[ARRAY_T, TYPE, FixedSize|DynamicSize(*DIMS), *FLAGS]'",
117118
)
118119

119120
numpy_array_fix.add_argument(
@@ -155,23 +156,20 @@ def regex(pattern_str: str) -> re.Pattern:
155156

156157

157158
def stub_parser_from_args(args) -> IParser:
158-
error_handlers: list[type] = [
159+
error_handlers_top: list[type] = [
159160
*([IgnoreAllErrors] if args.ignore_all_errors else []),
160161
*([IgnoreInvalidIdentifierErrors] if args.ignore_invalid_identifiers else []),
161162
*([IgnoreInvalidExpressionErrors] if args.ignore_invalid_expressions else []),
162163
*([IgnoreUnresolvedNameErrors] if args.ignore_unresolved_names else []),
163-
IgnoreFixedErrors,
164+
]
165+
error_handlers_bottom: list[type] = [
164166
LogErrors,
165167
SuggestCxxSignatureFix,
166168
*([TerminateOnFatalErrors] if args.exit_code else []),
167169
]
168170

169171
numpy_fixes: list[type] = [
170-
*(
171-
[FixNumpyArrayDimAnnotation]
172-
if args.numpy_array_wrap_with_annotated_fixed_size
173-
else []
174-
),
172+
*([FixNumpyArrayDimAnnotation] if args.numpy_array_wrap_with_annotated else []),
175173
*(
176174
[FixNumpyArrayRemoveParameters]
177175
if args.numpy_array_remove_parameters
@@ -180,7 +178,7 @@ def stub_parser_from_args(args) -> IParser:
180178
]
181179

182180
class Parser(
183-
*error_handlers, # type: ignore[misc]
181+
*error_handlers_top, # type: ignore[misc]
184182
FixMissing__future__AnnotationsImport,
185183
FixMissing__all__Attribute,
186184
FixMissingNoneHashFieldAnnotation,
@@ -191,6 +189,7 @@ class Parser(
191189
FixMissingFixedSizeImport,
192190
FixMissingEnumMembersAnnotation,
193191
*numpy_fixes, # type: ignore[misc]
192+
FixNumpyArrayFlags,
194193
FixCurrentModulePrefixInTypeNames,
195194
FixBuiltinTypes,
196195
FilterClassMembers,
@@ -205,6 +204,7 @@ class Parser(
205204
ExtractSignaturesFromPybind11Docstrings,
206205
ParserDispatchMixin,
207206
BaseParser,
207+
*error_handlers_bottom, # type: ignore[misc]
208208
):
209209
pass
210210

pybind11_stubgen/parser/mixins/error_handlers.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,30 +43,6 @@ def report_error(self, error: ParserError) -> None:
4343
super().report_error(error)
4444

4545

46-
class IgnoreFixedErrors(IParser):
47-
def report_error(self, error: ParserError):
48-
if isinstance(error, NameResolutionError):
49-
if error.name[0] in ["pybind11_builtins", "PyCapsule", "module"]:
50-
return
51-
elif isinstance(error, InvalidExpressionError):
52-
if error.expression.startswith("FixedSize"):
53-
# https://github.com/pybind/pybind11/pull/4679
54-
return
55-
elif isinstance(error, InvalidIdentifierError):
56-
name = error.name
57-
if (
58-
name.startswith("ItemsView[")
59-
and name.endswith("]")
60-
or name.startswith("KeysView[")
61-
and name.endswith("]")
62-
or name.startswith("ValuesView[")
63-
and name.endswith("]")
64-
):
65-
return
66-
67-
super().report_error(error)
68-
69-
7046
class IgnoreUnresolvedNameErrors(IParser):
7147
def __init__(self):
7248
super().__init__()
@@ -154,8 +130,8 @@ def __init__(self):
154130
self.__found_fatal_errors = False
155131

156132
def report_error(self, error: ParserError):
157-
super().report_error(error)
158133
self.__found_fatal_errors = True
134+
super().report_error(error)
159135

160136
def finalize(self):
161137
super().finalize()

pybind11_stubgen/parser/mixins/filter.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,6 @@ def handle_attribute(self, path: QualifiedName, attr: Any) -> Attribute | None:
7474
return None
7575
return super().handle_attribute(path, attr)
7676

77-
def handle_bases(
78-
self, path: QualifiedName, bases: tuple[type, ...]
79-
) -> list[QualifiedName]:
80-
result = []
81-
for base in super().handle_bases(path, bases):
82-
if base[0] == "pybind11_builtins":
83-
break
84-
result.append(base)
85-
return result
86-
8777

8878
class FilterPybindInternals(IParser):
8979
__attribute_blacklist: set[Identifier] = {*map(Identifier, ("__entries",))}

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 97 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import types
88
from typing import Any
99

10-
from pybind11_stubgen.parser.errors import NameResolutionError
10+
from pybind11_stubgen.parser.errors import NameResolutionError, ParserError
1111
from pybind11_stubgen.parser.interface import IParser
1212
from pybind11_stubgen.structs import (
1313
Alias,
@@ -27,7 +27,7 @@
2727
ResolvedType,
2828
Value,
2929
)
30-
from pybind11_stubgen.typing_ext import FixedSize
30+
from pybind11_stubgen.typing_ext import DynamicSize, FixedSize
3131

3232

3333
class RemoveSelfAnnotation(IParser):
@@ -252,6 +252,12 @@ def handle_type(self, type_: type) -> QualifiedName:
252252

253253
return result
254254

255+
def report_error(self, error: ParserError):
256+
if isinstance(error, NameResolutionError):
257+
if error.name[0] in ["PyCapsule"]:
258+
return
259+
super().report_error(error)
260+
255261

256262
class FixRedundantBuiltinsAnnotation(IParser):
257263
def handle_attribute(self, path: QualifiedName, attr: Any) -> Attribute | None:
@@ -451,7 +457,14 @@ def value_to_repr(self, value: Any) -> str:
451457

452458

453459
class FixNumpyArrayDimAnnotation(IParser):
454-
__ndarray_name = QualifiedName.from_str("numpy.ndarray")
460+
__array_names: set[QualifiedName] = {
461+
QualifiedName.from_str("numpy.ndarray"),
462+
*(
463+
QualifiedName.from_str(f"scipy.sparse.{storage}_{arr}")
464+
for storage in ["bsr", "coo", "csr", "csc", "dia", "dok", "lil"]
465+
for arr in ["array", "matrix"]
466+
),
467+
}
455468
__annotated_name = QualifiedName.from_str("typing.Annotated")
456469
numpy_primitive_types: set[QualifiedName] = set(
457470
map(
@@ -475,47 +488,79 @@ def parse_annotation_str(
475488
self, annotation_str: str
476489
) -> ResolvedType | InvalidExpression | Value:
477490
# Affects types of the following pattern:
478-
# numpy.ndarray[PRIMITIVE_TYPE[*DIMS]]
479-
# Annotated[numpy.ndarray, PRIMITIVE_TYPE, FixedSize[*DIMS]]
491+
# ARRAY_T[PRIMITIVE_TYPE[*DIMS], *FLAGS]
492+
# Replace with:
493+
# Annotated[ARRAY_T, PRIMITIVE_TYPE, FixedSize/DynamicSize[*DIMS], *FLAGS]
480494

481495
result = super().parse_annotation_str(annotation_str)
482496
if (
483497
not isinstance(result, ResolvedType)
484-
or result.name != self.__ndarray_name
498+
or result.name not in self.__array_names
485499
or result.parameters is None
486-
or len(result.parameters) != 1
500+
or len(result.parameters) == 0
487501
):
488502
return result
489-
param = result.parameters[0]
503+
504+
scalar_with_dims = result.parameters[0] # e.g. numpy.float64[32, 32]
505+
flags = result.parameters[1:]
506+
490507
if (
491-
not isinstance(param, ResolvedType)
492-
or param.name not in self.numpy_primitive_types
493-
or param.parameters is None
494-
or any(not isinstance(dim, Value) for dim in param.parameters)
508+
not isinstance(scalar_with_dims, ResolvedType)
509+
or scalar_with_dims.name not in self.numpy_primitive_types
510+
or (
511+
scalar_with_dims.parameters is not None
512+
and any(
513+
not isinstance(dim, Value) for dim in scalar_with_dims.parameters
514+
)
515+
)
495516
):
496517
return result
497518

498-
# isinstance check is redundant, but makes mypy happy
499-
dims = [int(dim.repr) for dim in param.parameters if isinstance(dim, Value)]
500-
501-
# override result with Annotated[...]
502519
result = ResolvedType(
503520
name=self.__annotated_name,
504521
parameters=[
505-
ResolvedType(self.__ndarray_name),
506-
ResolvedType(param.name),
522+
self.parse_annotation_str(str(result.name)),
523+
ResolvedType(scalar_with_dims.name),
507524
],
508525
)
509526

510-
if param.parameters is not None:
511-
# TRICK: Use `self.parse_type` to make `FixedSize`
512-
# properly added to the list of imports
513-
self.handle_type(FixedSize)
514-
assert result.parameters is not None
515-
result.parameters += [self.handle_value(FixedSize(*dims))]
527+
if (
528+
scalar_with_dims.parameters is not None
529+
and len(scalar_with_dims.parameters) >= 0
530+
):
531+
result.parameters += [
532+
self.handle_value(
533+
self._cook_dimension_parameters(scalar_with_dims.parameters)
534+
)
535+
]
536+
537+
result.parameters += flags
516538

517539
return result
518540

541+
def _cook_dimension_parameters(
542+
self, dimensions: list[Value]
543+
) -> FixedSize | DynamicSize:
544+
all_ints = True
545+
new_params = []
546+
for dim_param in dimensions:
547+
try:
548+
dim = int(dim_param.repr)
549+
except ValueError:
550+
dim = dim_param.repr
551+
all_ints = False
552+
new_params.append(dim)
553+
554+
if all_ints:
555+
return_t = FixedSize
556+
else:
557+
return_t = DynamicSize
558+
559+
# TRICK: Use `self.handle_type` to make `FixedSize`/`DynamicSize`
560+
# properly added to the list of imports
561+
self.handle_type(FixedSize)
562+
return return_t(*new_params)
563+
519564

520565
class FixNumpyArrayRemoveParameters(IParser):
521566
__ndarray_name = QualifiedName.from_str("numpy.ndarray")
@@ -529,6 +574,34 @@ def parse_annotation_str(
529574
return result
530575

531576

577+
class FixNumpyArrayFlags(IParser):
578+
__ndarray_name = QualifiedName.from_str("numpy.ndarray")
579+
__flags: set[QualifiedName] = {
580+
QualifiedName.from_str("flags.writeable"),
581+
QualifiedName.from_str("flags.c_contiguous"),
582+
QualifiedName.from_str("flags.f_contiguous"),
583+
}
584+
585+
def parse_annotation_str(
586+
self, annotation_str: str
587+
) -> ResolvedType | InvalidExpression | Value:
588+
result = super().parse_annotation_str(annotation_str)
589+
if isinstance(result, ResolvedType) and result.name == self.__ndarray_name:
590+
if result.parameters is not None:
591+
for param in result.parameters:
592+
if param.name in self.__flags:
593+
param.name = QualifiedName.from_str(
594+
f"numpy.ndarray.{param.name}"
595+
)
596+
597+
return result
598+
599+
def report_error(self, error: ParserError) -> None:
600+
if isinstance(error, NameResolutionError) and error.name in self.__flags:
601+
return
602+
super().report_error(error)
603+
604+
532605
class FixRedundantMethodsFromBuiltinObject(IParser):
533606
def handle_method(self, path: QualifiedName, method: Any) -> list[Method]:
534607
result = super().handle_method(path, method)

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
InvalidExpressionError,
1111
InvalidIdentifierError,
1212
NameResolutionError,
13+
ParserError,
1314
)
1415
from pybind11_stubgen.parser.interface import IParser
1516
from pybind11_stubgen.structs import (
@@ -379,6 +380,12 @@ def value_to_repr(self, value: Any) -> str:
379380
return "..."
380381
return repr(value)
381382

383+
def report_error(self, error: ParserError):
384+
if isinstance(error, NameResolutionError):
385+
if error.name[0] == "module":
386+
return
387+
super().report_error(error)
388+
382389
def _get_method_modifier(self, args: list[Argument]) -> Modifier:
383390
if len(args) == 0:
384391
return "static"
@@ -808,3 +815,31 @@ def _strip_empty_lines(self, doc_lines: list[str]) -> Docstring | None:
808815
if len(result) == 0:
809816
return None
810817
return Docstring(result)
818+
819+
def report_error(self, error: ParserError) -> None:
820+
if isinstance(error, NameResolutionError):
821+
if error.name[0] == "pybind11_builtins":
822+
return
823+
if isinstance(error, InvalidIdentifierError):
824+
name = error.name
825+
if (
826+
name.startswith("ItemsView[")
827+
and name.endswith("]")
828+
or name.startswith("KeysView[")
829+
and name.endswith("]")
830+
or name.startswith("ValuesView[")
831+
and name.endswith("]")
832+
):
833+
return
834+
835+
super().report_error(error)
836+
837+
def handle_bases(
838+
self, path: QualifiedName, bases: tuple[type, ...]
839+
) -> list[QualifiedName]:
840+
result = []
841+
for base in super().handle_bases(path, bases):
842+
if base[0] == "pybind11_builtins":
843+
break
844+
result.append(base)
845+
return result

0 commit comments

Comments
 (0)