Skip to content

Commit 081c6dc

Browse files
committed
TYP: fix mypy errors in numpy.typing.mypy_plugin
1 parent 2e700c6 commit 081c6dc

File tree

1 file changed

+52
-54
lines changed

1 file changed

+52
-54
lines changed

numpy/typing/mypy_plugin.py

Lines changed: 52 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,11 @@
3131
3232
"""
3333

34-
from __future__ import annotations
35-
36-
from typing import Final, TYPE_CHECKING, Callable
34+
from typing import Final
3735

3836
import numpy as np
3937

40-
if TYPE_CHECKING:
41-
from collections.abc import Iterable
42-
43-
try:
44-
import mypy.types
45-
from mypy.types import Type
46-
from mypy.plugin import Plugin, AnalyzeTypeContext
47-
from mypy.nodes import MypyFile, ImportFrom, Statement
48-
from mypy.build import PRI_MED
49-
50-
_HookFunc = Callable[[AnalyzeTypeContext], Type]
51-
MYPY_EX: None | ModuleNotFoundError = None
52-
except ModuleNotFoundError as ex:
53-
MYPY_EX = ex
54-
55-
__all__: list[str] = []
38+
__all__ = ()
5639

5740

5841
def _get_precision_dict() -> dict[str, str]:
@@ -70,11 +53,10 @@ def _get_precision_dict() -> dict[str, str]:
7053
("_NBitDouble", np.double),
7154
("_NBitLongDouble", np.longdouble),
7255
]
73-
ret = {}
74-
module = "numpy._typing"
56+
ret: dict[str, str] = {}
7557
for name, typ in names:
76-
n: int = 8 * typ().dtype.itemsize
77-
ret[f'{module}._nbit.{name}'] = f"{module}._nbit_base._{n}Bit"
58+
n = 8 * np.dtype(typ).itemsize
59+
ret[f"{_MODULE}._nbit.{name}"] = f"{_MODULE}._nbit_base._{n}Bit"
7860
return ret
7961

8062

@@ -97,16 +79,14 @@ def _get_extended_precision_list() -> list[str]:
9779

9880
def _get_c_intp_name() -> str:
9981
# Adapted from `np.core._internal._getintp_ctype`
100-
char = np.dtype('n').char
101-
if char == 'i':
102-
return "c_int"
103-
elif char == 'l':
104-
return "c_long"
105-
elif char == 'q':
106-
return "c_longlong"
107-
else:
108-
return "c_long"
82+
return {
83+
"i": "c_int",
84+
"l": "c_long",
85+
"q": "c_longlong",
86+
}.get(np.dtype("n").char, "c_long")
87+
10988

89+
_MODULE: Final = "numpy._typing"
11090

11191
#: A dictionary mapping type-aliases in `numpy._typing._nbit` to
11292
#: concrete `numpy.typing.NBitBase` subclasses.
@@ -119,15 +99,30 @@ def _get_c_intp_name() -> str:
11999
_C_INTP: Final = _get_c_intp_name()
120100

121101

122-
def _hook(ctx: AnalyzeTypeContext) -> Type:
123-
"""Replace a type-alias with a concrete ``NBitBase`` subclass."""
124-
typ, _, api = ctx
125-
name = typ.name.split(".")[-1]
126-
name_new = _PRECISION_DICT[f"numpy._typing._nbit.{name}"]
127-
return api.named_type(name_new)
102+
try:
103+
from collections.abc import Callable, Iterable
104+
from typing import TYPE_CHECKING, TypeAlias, cast
105+
106+
if TYPE_CHECKING:
107+
from mypy.typeanal import TypeAnalyser
108+
109+
import mypy.types
110+
from mypy.plugin import Plugin, AnalyzeTypeContext
111+
from mypy.nodes import MypyFile, ImportFrom, Statement
112+
from mypy.build import PRI_MED
113+
114+
115+
_HookFunc: TypeAlias = Callable[[AnalyzeTypeContext], mypy.types.Type]
116+
117+
118+
def _hook(ctx: AnalyzeTypeContext) -> mypy.types.Type:
119+
"""Replace a type-alias with a concrete ``NBitBase`` subclass."""
120+
typ, _, api = ctx
121+
name = typ.name.split(".")[-1]
122+
name_new = _PRECISION_DICT[f"{_MODULE}._nbit.{name}"]
123+
return cast("TypeAnalyser", api).named_type(name_new)
128124

129125

130-
if TYPE_CHECKING or MYPY_EX is None:
131126
def _index(iterable: Iterable[Statement], id: str) -> int:
132127
"""Identify the first ``ImportFrom`` instance the specified `id`."""
133128
for i, value in enumerate(iterable):
@@ -139,22 +134,23 @@ def _index(iterable: Iterable[Statement], id: str) -> int:
139134
def _override_imports(
140135
file: MypyFile,
141136
module: str,
142-
imports: list[tuple[str, None | str]],
137+
imports: list[tuple[str, str | None]],
143138
) -> None:
144139
"""Override the first `module`-based import with new `imports`."""
145140
# Construct a new `from module import y` statement
146141
import_obj = ImportFrom(module, 0, names=imports)
147142
import_obj.is_top_level = True
148143

149144
# Replace the first `module`-based import statement with `import_obj`
150-
for lst in [file.defs, file.imports]: # type: list[Statement]
145+
for lst in [file.defs, cast("list[Statement]", file.imports)]:
151146
i = _index(lst, module)
152147
lst[i] = import_obj
153148

149+
154150
class _NumpyPlugin(Plugin):
155151
"""A mypy plugin for handling versus numpy-specific typing tasks."""
156152

157-
def get_type_analyze_hook(self, fullname: str) -> None | _HookFunc:
153+
def get_type_analyze_hook(self, fullname: str) -> _HookFunc | None:
158154
"""Set the precision of platform-specific `numpy.number`
159155
subclasses.
160156
@@ -175,25 +171,27 @@ def get_additional_deps(
175171
* Import the appropriate `ctypes` equivalent to `numpy.intp`.
176172
177173
"""
178-
ret = [(PRI_MED, file.fullname, -1)]
179-
180-
if file.fullname == "numpy":
174+
fullname = file.fullname
175+
if fullname == "numpy":
181176
_override_imports(
182-
file, "numpy._typing._extended_precision",
177+
file,
178+
f"{_MODULE}._extended_precision",
183179
imports=[(v, v) for v in _EXTENDED_PRECISION_LIST],
184180
)
185-
elif file.fullname == "numpy.ctypeslib":
181+
elif fullname == "numpy.ctypeslib":
186182
_override_imports(
187-
file, "ctypes",
183+
file,
184+
"ctypes",
188185
imports=[(_C_INTP, "_c_intp")],
189186
)
190-
return ret
187+
return [(PRI_MED, fullname, -1)]
191188

192-
def plugin(version: str) -> type[_NumpyPlugin]:
189+
190+
def plugin(version: str) -> type:
193191
"""An entry-point for mypy."""
194192
return _NumpyPlugin
195193

196-
else:
197-
def plugin(version: str) -> type[_NumpyPlugin]:
198-
"""An entry-point for mypy."""
199-
raise MYPY_EX
194+
except ModuleNotFoundError as e:
195+
196+
def plugin(version: str) -> type:
197+
raise e

0 commit comments

Comments
 (0)