Skip to content

Commit 60728d3

Browse files
authored
feat: ✨ numpy-array-use-type-var flag (#188)
1 parent af3c6af commit 60728d3

File tree

178 files changed

+1158
-24
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

178 files changed

+1158
-24
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
if: ${{ failure() || success() }}
3939

4040
tests:
41-
name: "Test 🐍 ${{ matrix.python }} • pybind-${{ matrix.pybind11-branch }}"
41+
name: "Test 🐍 ${{ matrix.python }} • pybind-${{ matrix.pybind11-branch }} • ${{ matrix.numpy-format }}"
4242
runs-on: ubuntu-latest
4343
strategy:
4444
fail-fast: false
@@ -52,11 +52,18 @@ jobs:
5252
- "3.9"
5353
- "3.8"
5454
- "3.7"
55+
numpy-format:
56+
- "numpy-array-wrap-with-annotated"
5557
include:
56-
- pybind11-branch: "v2.9"
57-
python: "3.12"
58-
- pybind11-branch: "v2.11"
59-
python: "3.12"
58+
- python: "3.12"
59+
pybind11-branch: "v2.9"
60+
numpy-format: "numpy-array-wrap-with-annotated"
61+
- python: "3.12"
62+
pybind11-branch: "v2.11"
63+
numpy-format: "numpy-array-wrap-with-annotated"
64+
- python: "3.12"
65+
pybind11-branch: "master"
66+
numpy-format: "numpy-array-use-type-var"
6067
steps:
6168
- uses: actions/checkout@v3
6269

@@ -84,7 +91,7 @@ jobs:
8491

8592
- name: Check stubs generation
8693
shell: bash
87-
run: ./tests/check-demo-stubs-generation.sh --stubs-sub-dir "stubs/python-${{ matrix.python }}/pybind11-${{ matrix.pybind11-branch }}"
94+
run: ./tests/check-demo-stubs-generation.sh --stubs-sub-dir "stubs/python-${{ matrix.python }}/pybind11-${{ matrix.pybind11-branch }}/${{ matrix.numpy-format }}" --${{ matrix.numpy-format }}
8895

8996
- name: Archive patch
9097
uses: actions/upload-artifact@v3
@@ -137,6 +144,7 @@ jobs:
137144
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-1 --numpy-array-wrap-with-annotated
138145
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-2 --numpy-array-remove-parameters
139146
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-3 --print-invalid-expressions-as-is
147+
pybind11-stubgen "${{ matrix.test-package }}" -o flavour-4 --numpy-array-use-type-var
140148
pybind11-stubgen "${{ matrix.test-package }}" --dry-run
141149
142150
publish:

pybind11_stubgen/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@
3333
FixMissingImports,
3434
FixMissingNoneHashFieldAnnotation,
3535
FixNumpyArrayDimAnnotation,
36+
FixNumpyArrayDimTypeVar,
3637
FixNumpyArrayFlags,
3738
FixNumpyArrayRemoveParameters,
3839
FixNumpyDtype,
3940
FixPEP585CollectionNames,
4041
FixPybind11EnumStrDoc,
4142
FixRedundantBuiltinsAnnotation,
4243
FixRedundantMethodsFromBuiltinObject,
44+
FixScipyTypeArguments,
4345
FixTypingTypeNames,
4446
FixValueReprRandomAddress,
4547
OverridePrintSafeValues,
@@ -66,6 +68,7 @@ class CLIArgs(Namespace):
6668
ignore_all_errors: bool
6769
enum_class_locations: list[tuple[re.Pattern, str]]
6870
numpy_array_wrap_with_annotated: bool
71+
numpy_array_use_type_var: bool
6972
numpy_array_remove_parameters: bool
7073
print_invalid_expressions_as_is: bool
7174
print_safe_value_reprs: re.Pattern | None
@@ -156,6 +159,13 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
156159
"'ARRAY_T[TYPE, [*DIMS], *FLAGS]' format with "
157160
"'Annotated[ARRAY_T, TYPE, FixedSize|DynamicSize(*DIMS), *FLAGS]'",
158161
)
162+
numpy_array_fix.add_argument(
163+
"--numpy-array-use-type-var",
164+
default=False,
165+
action="store_true",
166+
help="Replace 'numpy.ndarray[numpy.float32[m, 1]]' with "
167+
"'numpy.ndarray[tuple[M, typing.Literal[1]], numpy.dtype[numpy.float32]]'",
168+
)
159169

160170
numpy_array_fix.add_argument(
161171
"--numpy-array-remove-parameters",
@@ -230,6 +240,7 @@ def stub_parser_from_args(args: CLIArgs) -> IParser:
230240

231241
numpy_fixes: list[type] = [
232242
*([FixNumpyArrayDimAnnotation] if args.numpy_array_wrap_with_annotated else []),
243+
*([FixNumpyArrayDimTypeVar] if args.numpy_array_use_type_var else []),
233244
*(
234245
[FixNumpyArrayRemoveParameters]
235246
if args.numpy_array_remove_parameters
@@ -246,6 +257,7 @@ class Parser(
246257
FilterTypingModuleAttributes,
247258
FixPEP585CollectionNames,
248259
FixTypingTypeNames,
260+
FixScipyTypeArguments,
249261
FixMissingFixedSizeImport,
250262
FixMissingEnumMembersAnnotation,
251263
OverridePrintSafeValues,

pybind11_stubgen/parser/mixins/fix.py

Lines changed: 174 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Property,
3232
QualifiedName,
3333
ResolvedType,
34+
TypeVar_,
3435
Value,
3536
)
3637
from pybind11_stubgen.typing_ext import DynamicSize, FixedSize
@@ -335,6 +336,7 @@ class FixTypingTypeNames(IParser):
335336
"Iterator",
336337
"KeysView",
337338
"List",
339+
"Literal",
338340
"Optional",
339341
"Sequence",
340342
"Set",
@@ -360,12 +362,28 @@ def __init__(self):
360362
super().__init__()
361363
if sys.version_info < (3, 9):
362364
self.__typing_extensions_names.add(Identifier("Annotated"))
365+
if sys.version_info < (3, 8):
366+
self.__typing_extensions_names.add(Identifier("Literal"))
363367

364368
def parse_annotation_str(
365369
self, annotation_str: str
366370
) -> ResolvedType | InvalidExpression | Value:
367371
result = super().parse_annotation_str(annotation_str)
368-
if not isinstance(result, ResolvedType) or len(result.name) != 1:
372+
return self._parse_annotation_str(result)
373+
374+
def _parse_annotation_str(
375+
self, result: ResolvedType | InvalidExpression | Value
376+
) -> ResolvedType | InvalidExpression | Value:
377+
if not isinstance(result, ResolvedType):
378+
return result
379+
380+
result.parameters = (
381+
[self._parse_annotation_str(p) for p in result.parameters]
382+
if result.parameters is not None
383+
else None
384+
)
385+
386+
if len(result.name) != 1:
369387
return result
370388

371389
word = result.name[0]
@@ -582,6 +600,136 @@ def report_error(self, error: ParserError) -> None:
582600
super().report_error(error)
583601

584602

603+
class FixNumpyArrayDimTypeVar(IParser):
    """Rewrite numpy array annotations into a ``[shape, dtype]`` generic form.

    Annotations matching ``numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]`` get
    their parameters replaced by exactly two entries: a shape (a ``Tuple`` of
    ``Literal`` integers and dimension variables, or ``Any``) and
    ``numpy.dtype[PRIMITIVE_TYPE]``.  Any further parameters (the flags) are
    dropped.  Unqualified single-letter names are upper-cased, remembered as
    dimension variables, and emitted as ``TypeVar_`` entries on the module.
    """

    __array_names: set[QualifiedName] = {QualifiedName.from_str("numpy.ndarray")}
    numpy_primitive_types = FixNumpyArrayDimAnnotation.numpy_primitive_types

    # Upper-cased names of the dimension variables collected so far.  The set
    # is a class attribute that is only mutated in place (never rebound); it
    # is emptied at the end of `handle_module`.
    __DIM_VARS: set[str] = set()

    def handle_module(
        self, path: QualifiedName, module: types.ModuleType
    ) -> Module | None:
        """Attach the collected dimension variables to the handled module.

        Returns ``None`` untouched when the parent handler yields ``None``;
        otherwise adds one ``TypeVar_`` (bound to the parsed ``int``
        annotation) per collected name, plus a ``typing`` import when at
        least one name was collected, then clears the collected names.
        """
        parsed_module = super().handle_module(path, module)
        if parsed_module is None:
            return None

        if self.__DIM_VARS:
            # the TypeVar_'s generated code will reference `typing`
            typing_import = Import(
                name=None, origin=QualifiedName.from_str("typing")
            )
            parsed_module.imports.add(typing_import)

        # Iterating an empty set is a no-op, so no extra guard is needed here.
        for var_name in self.__DIM_VARS:
            dim_var = TypeVar_(
                name=Identifier(var_name),
                bound=self.parse_annotation_str("int"),
            )
            parsed_module.type_vars.append(dim_var)

        self.__DIM_VARS.clear()
        return parsed_module

    def parse_annotation_str(
        self, annotation_str: str
    ) -> ResolvedType | InvalidExpression | Value:
        """Parse an annotation and normalise numpy array types.

        Affects types of the following pattern:
            numpy.ndarray[PRIMITIVE_TYPE[*DIMS], *FLAGS]
        Replace with:
            numpy.ndarray[tuple[M, Literal[1]], numpy.dtype[numpy.float32]]

        Results that are not a ``ResolvedType`` are returned unchanged.
        """
        annotation = super().parse_annotation_str(annotation_str)
        if not isinstance(annotation, ResolvedType):
            return annotation

        # handle unqualified, single-letter annotation as a TypeVar
        original_name = annotation.name
        if len(original_name) == 1 and len(original_name[0]) == 1:
            annotation.name = QualifiedName.from_str(original_name[0].upper())
            self.__DIM_VARS.add(annotation.name[0])

        if annotation.name not in self.__array_names:
            return annotation

        # ndarray is generic and should have 2 type arguments
        if not annotation.parameters:
            any_shape = self.parse_annotation_str("Any")
            any_dtype = ResolvedType(
                name=QualifiedName.from_str("numpy.dtype"),
                parameters=[self.parse_annotation_str("Any")],
            )
            annotation.parameters = [any_shape, any_dtype]
            return annotation

        scalar = annotation.parameters[0]  # e.g. numpy.float64[32, 32]
        is_primitive_scalar = (
            isinstance(scalar, ResolvedType)
            and scalar.name in self.numpy_primitive_types
        )
        if not is_primitive_scalar:
            return annotation

        dtype = ResolvedType(
            name=QualifiedName.from_str("numpy.dtype"),
            parameters=[ResolvedType(name=scalar.name)],
        )

        # `Any` is always parsed first (as the fallback shape), even when it
        # is replaced by a `Tuple` below.
        shape = self.parse_annotation_str("Any")
        dims = self.__to_dims(scalar.parameters) if scalar.parameters else None
        if dims is not None:
            shape = self.parse_annotation_str("Tuple")
            assert isinstance(shape, ResolvedType)
            dim_annotations = []
            for dim in dims:
                if isinstance(dim, int):
                    # self.parse_annotation_str will qualify Literal with either
                    # typing or typing_extensions and add the import to the module
                    literal = self.parse_annotation_str("Literal")
                    assert isinstance(literal, ResolvedType)
                    literal.parameters = [Value(repr=str(dim))]
                    dim_annotations.append(literal)
                else:
                    dim_annotations.append(
                        ResolvedType(name=QualifiedName.from_str(dim))
                    )
            shape.parameters = dim_annotations

        annotation.parameters = [shape, dtype]
        return annotation

    def __to_dims(
        self, dimensions: Sequence[ResolvedType | Value | InvalidExpression]
    ) -> list[int | str] | None:
        """Convert dimension parameters to ints (sizes) or strs (names).

        Returns ``None`` as soon as one entry is neither a ``Value`` whose
        ``repr`` parses as an integer nor a ``ResolvedType``.
        """
        dims: list[int | str] = []
        for item in dimensions:
            if isinstance(item, Value):
                try:
                    size = int(item.repr)
                except ValueError:
                    return None
                dims.append(size)
            elif isinstance(item, ResolvedType):
                dims.append(str(item))
            else:
                return None
        return dims

    def report_error(self, error: ParserError) -> None:
        """Forward errors, except name-resolution errors for dimension variables."""
        is_dim_var_error = (
            isinstance(error, NameResolutionError)
            and len(error.name) == 1
            and error.name[0] in self.__DIM_VARS
        )
        # type variables are allowed: they are manually resolved in `handle_module`
        if not is_dim_var_error:
            super().report_error(error)
731+
732+
585733
class FixNumpyArrayRemoveParameters(IParser):
586734
__ndarray_name = QualifiedName.from_str("numpy.ndarray")
587735

@@ -594,24 +742,40 @@ def parse_annotation_str(
594742
return result
595743

596744

745+
class FixScipyTypeArguments(IParser):
    """Strip type arguments from annotations that live under ``scipy.sparse``."""

    def parse_annotation_str(
        self, annotation_str: str
    ) -> ResolvedType | InvalidExpression | Value:
        """Parse the annotation and drop parameters of ``scipy.sparse.*`` types.

        Anything that is not a ``ResolvedType`` is passed through untouched.
        """
        parsed = super().parse_annotation_str(annotation_str)

        # scipy.sparse arrays/matrices are not currently generic and do not accept
        # type arguments
        if isinstance(parsed, ResolvedType) and parsed.name[:2] == (
            "scipy",
            "sparse",
        ):
            parsed.parameters = None

        return parsed
760+
761+
597762
class FixNumpyDtype(IParser):
598763
__numpy_dtype = QualifiedName.from_str("numpy.dtype")
599764

600765
def parse_annotation_str(
601766
self, annotation_str: str
602767
) -> ResolvedType | InvalidExpression | Value:
603768
result = super().parse_annotation_str(annotation_str)
604-
if (
605-
not isinstance(result, ResolvedType)
606-
or len(result.name) != 1
607-
or result.parameters is not None
608-
):
609-
return result
610769

611-
word = result.name[0]
612-
if word != Identifier("dtype"):
770+
if not isinstance(result, ResolvedType) or result.parameters:
613771
return result
614-
return ResolvedType(name=self.__numpy_dtype)
772+
773+
# numpy.dtype is generic and should have a type argument
774+
if result.name[:1] == ("dtype",) or result.name[:2] == ("numpy", "dtype"):
775+
result.name = self.__numpy_dtype
776+
result.parameters = [self.parse_annotation_str("Any")]
777+
778+
return result
615779

616780

617781
class FixNumpyArrayFlags(IParser):

pybind11_stubgen/parser/mixins/parse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Property,
3232
QualifiedName,
3333
ResolvedType,
34+
TypeVar_,
3435
Value,
3536
)
3637

@@ -103,6 +104,8 @@ def handle_module(
103104
result.sub_modules.append(obj)
104105
elif isinstance(obj, Attribute):
105106
result.attributes.append(obj)
107+
elif isinstance(obj, TypeVar_):
108+
result.type_vars.append(obj)
106109
elif obj is None:
107110
pass
108111
else:

pybind11_stubgen/printer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Module,
2121
Property,
2222
ResolvedType,
23+
TypeVar_,
2324
Value,
2425
)
2526

@@ -81,6 +82,9 @@ def print_class(self, class_: Class) -> list[str]:
8182
*indent_lines(self.print_class_body(class_)),
8283
]
8384

85+
def print_type_var(self, type_var: TypeVar_) -> list[str]:
    """Return the printed form of *type_var* as a single-element list of lines."""
    rendered = str(type_var)
    return [rendered]
87+
8488
def print_class_body(self, class_: Class) -> list[str]:
8589
result = []
8690
if class_.doc is not None:
@@ -215,6 +219,9 @@ def print_module(self, module: Module) -> list[str]:
215219
result.extend(self.print_attribute(attr))
216220
break
217221

222+
for type_var in sorted(module.type_vars, key=lambda t: t.name):
223+
result.extend(self.print_type_var(type_var))
224+
218225
for class_ in sorted(module.classes, key=lambda c: c.name):
219226
result.extend(self.print_class(class_))
220227

0 commit comments

Comments
 (0)