31
31
32
32
"""
33
33
34
- from __future__ import annotations
35
-
36
- from typing import Final , TYPE_CHECKING , Callable
34
+ from typing import Final
37
35
38
36
import numpy as np
39
37
40
- if TYPE_CHECKING :
41
- from collections .abc import Iterable
42
-
43
- try :
44
- import mypy .types
45
- from mypy .types import Type
46
- from mypy .plugin import Plugin , AnalyzeTypeContext
47
- from mypy .nodes import MypyFile , ImportFrom , Statement
48
- from mypy .build import PRI_MED
49
-
50
- _HookFunc = Callable [[AnalyzeTypeContext ], Type ]
51
- MYPY_EX : None | ModuleNotFoundError = None
52
- except ModuleNotFoundError as ex :
53
- MYPY_EX = ex
54
-
55
- __all__ : list [str ] = []
38
+ __all__ = ()
56
39
57
40
58
41
def _get_precision_dict () -> dict [str , str ]:
@@ -70,11 +53,10 @@ def _get_precision_dict() -> dict[str, str]:
70
53
("_NBitDouble" , np .double ),
71
54
("_NBitLongDouble" , np .longdouble ),
72
55
]
73
- ret = {}
74
- module = "numpy._typing"
56
+ ret : dict [str , str ] = {}
75
57
for name , typ in names :
76
- n : int = 8 * typ () .dtype .itemsize
77
- ret [f' { module } ._nbit.{ name } ' ] = f"{ module } ._nbit_base._{ n } Bit"
58
+ n = 8 * np .dtype ( typ ) .itemsize
59
+ ret [f" { _MODULE } ._nbit.{ name } " ] = f"{ _MODULE } ._nbit_base._{ n } Bit"
78
60
return ret
79
61
80
62
@@ -97,16 +79,14 @@ def _get_extended_precision_list() -> list[str]:
97
79
98
80
def _get_c_intp_name () -> str :
99
81
# Adapted from `np.core._internal._getintp_ctype`
100
- char = np .dtype ('n' ).char
101
- if char == 'i' :
102
- return "c_int"
103
- elif char == 'l' :
104
- return "c_long"
105
- elif char == 'q' :
106
- return "c_longlong"
107
- else :
108
- return "c_long"
82
+ return {
83
+ "i" : "c_int" ,
84
+ "l" : "c_long" ,
85
+ "q" : "c_longlong" ,
86
+ }.get (np .dtype ("n" ).char , "c_long" )
87
+
109
88
89
+ _MODULE : Final = "numpy._typing"
110
90
111
91
#: A dictionary mapping type-aliases in `numpy._typing._nbit` to
112
92
#: concrete `numpy.typing.NBitBase` subclasses.
@@ -119,15 +99,30 @@ def _get_c_intp_name() -> str:
119
99
_C_INTP : Final = _get_c_intp_name ()
120
100
121
101
122
- def _hook (ctx : AnalyzeTypeContext ) -> Type :
123
- """Replace a type-alias with a concrete ``NBitBase`` subclass."""
124
- typ , _ , api = ctx
125
- name = typ .name .split ("." )[- 1 ]
126
- name_new = _PRECISION_DICT [f"numpy._typing._nbit.{ name } " ]
127
- return api .named_type (name_new )
102
+ try :
103
+ from collections .abc import Callable , Iterable
104
+ from typing import TYPE_CHECKING , TypeAlias , cast
105
+
106
+ if TYPE_CHECKING :
107
+ from mypy .typeanal import TypeAnalyser
108
+
109
+ import mypy .types
110
+ from mypy .plugin import Plugin , AnalyzeTypeContext
111
+ from mypy .nodes import MypyFile , ImportFrom , Statement
112
+ from mypy .build import PRI_MED
113
+
114
+
115
+ _HookFunc : TypeAlias = Callable [[AnalyzeTypeContext ], mypy .types .Type ]
116
+
117
+
118
+ def _hook (ctx : AnalyzeTypeContext ) -> mypy .types .Type :
119
+ """Replace a type-alias with a concrete ``NBitBase`` subclass."""
120
+ typ , _ , api = ctx
121
+ name = typ .name .split ("." )[- 1 ]
122
+ name_new = _PRECISION_DICT [f"{ _MODULE } ._nbit.{ name } " ]
123
+ return cast ("TypeAnalyser" , api ).named_type (name_new )
128
124
129
125
130
- if TYPE_CHECKING or MYPY_EX is None :
131
126
def _index (iterable : Iterable [Statement ], id : str ) -> int :
132
127
"""Identify the first ``ImportFrom`` instance the specified `id`."""
133
128
for i , value in enumerate (iterable ):
@@ -139,22 +134,23 @@ def _index(iterable: Iterable[Statement], id: str) -> int:
139
134
def _override_imports (
140
135
file : MypyFile ,
141
136
module : str ,
142
- imports : list [tuple [str , None | str ]],
137
+ imports : list [tuple [str , str | None ]],
143
138
) -> None :
144
139
"""Override the first `module`-based import with new `imports`."""
145
140
# Construct a new `from module import y` statement
146
141
import_obj = ImportFrom (module , 0 , names = imports )
147
142
import_obj .is_top_level = True
148
143
149
144
# Replace the first `module`-based import statement with `import_obj`
150
- for lst in [file .defs , file .imports ]: # type: list[Statement]
145
+ for lst in [file .defs , cast ( "list[Statement]" , file .imports )]:
151
146
i = _index (lst , module )
152
147
lst [i ] = import_obj
153
148
149
+
154
150
class _NumpyPlugin (Plugin ):
155
151
"""A mypy plugin for handling versus numpy-specific typing tasks."""
156
152
157
- def get_type_analyze_hook (self , fullname : str ) -> None | _HookFunc :
153
+ def get_type_analyze_hook (self , fullname : str ) -> _HookFunc | None :
158
154
"""Set the precision of platform-specific `numpy.number`
159
155
subclasses.
160
156
@@ -175,25 +171,27 @@ def get_additional_deps(
175
171
* Import the appropriate `ctypes` equivalent to `numpy.intp`.
176
172
177
173
"""
178
- ret = [(PRI_MED , file .fullname , - 1 )]
179
-
180
- if file .fullname == "numpy" :
174
+ fullname = file .fullname
175
+ if fullname == "numpy" :
181
176
_override_imports (
182
- file , "numpy._typing._extended_precision" ,
177
+ file ,
178
+ f"{ _MODULE } ._extended_precision" ,
183
179
imports = [(v , v ) for v in _EXTENDED_PRECISION_LIST ],
184
180
)
185
- elif file . fullname == "numpy.ctypeslib" :
181
+ elif fullname == "numpy.ctypeslib" :
186
182
_override_imports (
187
- file , "ctypes" ,
183
+ file ,
184
+ "ctypes" ,
188
185
imports = [(_C_INTP , "_c_intp" )],
189
186
)
190
- return ret
187
+ return [( PRI_MED , fullname , - 1 )]
191
188
192
- def plugin (version : str ) -> type [_NumpyPlugin ]:
189
+
190
+ def plugin (version : str ) -> type :
193
191
"""An entry-point for mypy."""
194
192
return _NumpyPlugin
195
193
196
- else :
197
- def plugin ( version : str ) -> type [ _NumpyPlugin ]:
198
- """An entry-point for mypy."""
199
- raise MYPY_EX
194
+ except ModuleNotFoundError as e :
195
+
196
+ def plugin ( version : str ) -> type :
197
+ raise e
0 commit comments