Skip to content

Commit 7508dcb

Browse files
authored
🐛 sparse: add the missing sparse array/matrix dunder methods (#391)
2 parents 3945a95 + e80b5e8 commit 7508dcb

File tree

8 files changed

+378
-163
lines changed

8 files changed

+378
-163
lines changed

.mypyignore

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ scipy\.fftpack\.(_?helper\.)?fftfreq
1414

1515
# accidental implicit exports of internal scipy machinery
1616
scipy\._lib\.decorator\.(DEF|ArgSpec|FunctionMaker|__init__|append|dispatch_on|get_init|getargspec|init|n_args)
17-
scipy\.special\._precompute\..* # TODO??
1817
scipy\.special\.libsf_error_state
1918
scipy\.stats\._rcont\.rcont
2019

20+
# why is this even included in the wheels?
21+
scipy\.special\._precompute\..*
22+
2123
# omitted methods that always return `NotImplemented` or always raise
22-
scipy\.sparse\._(\w+)\._(\w+)\.__(len|i(add|mul|sub)|(i|r)(true)?div)__
24+
# scipy\.sparse\._(\w+)\._(\w+)\.__(len|i(add|mul|sub)|(i|r)(true)?div)__
2325

2426
# workarounds for mypy bugs
2527
scipy\.signal\._short_time_fft\.(FFT_MODE_TYPE|PAD_TYPE) # `Literal[...] != def (*, **)`

scipy-stubs/sparse/_base.pyi

Lines changed: 119 additions & 129 deletions
Large diffs are not rendered by default.

scipy-stubs/sparse/_csc.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import Any, Generic, Literal, overload
22
from typing_extensions import TypeIs, TypeVar, override
33

4+
import numpy as np
45
import optype as op
6+
import optype.numpy as onp
57
from ._base import sparray
68
from ._compressed import _cs_matrix
79
from ._matrix import spmatrix
@@ -24,6 +26,12 @@ class _csc_base(_cs_matrix[_SCT, tuple[int, int]], Generic[_SCT]):
2426
@override
2527
def shape(self, /) -> tuple[int, int]: ...
2628

29+
#
30+
@overload # type: ignore[explicit-override]
31+
def count_nonzero(self, /, axis: None = None) -> int: ...
32+
@overload
33+
def count_nonzero(self, /, axis: op.CanIndex) -> onp.Array1D[np.intp]: ...
34+
2735
class csc_array(_csc_base[_SCT], sparray, Generic[_SCT]): ...
2836

2937
class csc_matrix(_csc_base[_SCT], spmatrix[_SCT], Generic[_SCT]): # type: ignore[misc]

scipy-stubs/sparse/_csr.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import Any, Generic, Literal, overload
22
from typing_extensions import TypeIs, TypeVar, override
33

4+
import numpy as np
45
import optype as op
6+
import optype.numpy as onp
57
from ._base import sparray
68
from ._compressed import _cs_matrix
79
from ._matrix import spmatrix
@@ -22,6 +24,16 @@ class _csr_base(_cs_matrix[_SCT, _ShapeT_co], Generic[_SCT, _ShapeT_co]):
2224
@override
2325
def format(self, /) -> Literal["csr"]: ...
2426

27+
#
28+
@overload # type: ignore[explicit-override]
29+
def count_nonzero(self, /, axis: None = None) -> int: ...
30+
@overload
31+
def count_nonzero(self: _csr_base[Any, tuple[int]], /, axis: op.CanIndex) -> int: ...
32+
@overload
33+
def count_nonzero(self: _csr_base[Any, tuple[int, int]], /, axis: op.CanIndex) -> onp.Array1D[np.intp]: ...
34+
@overload
35+
def count_nonzero(self: csr_array, /, axis: op.CanIndex) -> int | onp.Array1D[np.intp]: ... # type: ignore[misc]
36+
2537
class csr_array(_csr_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]): ...
2638

2739
class csr_matrix(_csr_base[_SCT, tuple[int, int]], spmatrix[_SCT], Generic[_SCT]): # type: ignore[misc]

scipy-stubs/sparse/_data.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,10 @@ class _data_matrix(_spbase[_SCT_co, _ShapeT_co], Generic[_SCT_co, _ShapeT_co]):
7676
def __init__(self, /, arg1: onp.CanArrayND[_SCT_co], *, maxprint: int | None = None) -> None: ...
7777

7878
#
79-
def __imul__(self, rhs: _ScalarLike, /) -> Self: ... # type: ignore[misc,override]
80-
def __itruediv__(self, rhs: _ScalarLike, /) -> Self: ... # type: ignore[misc,override]
79+
@override
80+
def __imul__(self, rhs: _ScalarLike, /) -> Self: ... # type: ignore[override]
81+
@override
82+
def __itruediv__(self, rhs: _ScalarLike, /) -> Self: ... # type: ignore[override]
8183

8284
# NOTE: The following methods do not convert the scalar type
8385
def sign(self, /) -> Self: ...

scipy-stubs/sparse/_dok.pyi

Lines changed: 193 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
# mypy: disable-error-code="misc, override"
1+
# NOTE: Adding `@override` to `@overload`ed methods will crash stubtest (basedmypy 1.13.0)
2+
# mypy: disable-error-code="misc, override, explicit-override"
23
# pyright: reportIncompatibleMethodOverride=false
34

45
from collections.abc import Iterable, Sequence
56
from typing import Any, Generic, Literal, TypeAlias, overload
67
from typing_extensions import Never, Self, TypeIs, TypeVar, override
78

89
import numpy as np
10+
import optype as op
911
import optype.numpy as onp
1012
import optype.typing as opt
1113
from ._base import _spbase, sparray
@@ -15,14 +17,25 @@ from ._typing import Scalar, ShapeDOK, ToShape1dNd
1517

1618
__all__ = ["dok_array", "dok_matrix", "isspmatrix_dok"]
1719

20+
###
21+
1822
_T = TypeVar("_T")
1923
_SCT = TypeVar("_SCT", bound=Scalar, default=Any)
2024
_ShapeT_co = TypeVar("_ShapeT_co", bound=ShapeDOK, default=ShapeDOK, covariant=True)
2125

26+
_1D: TypeAlias = tuple[int] # noqa: PYI042
27+
_2D: TypeAlias = tuple[int, int] # noqa: PYI042
28+
2229
_ToDType: TypeAlias = type[_SCT] | np.dtype[_SCT] | onp.HasDType[np.dtype[_SCT]]
2330
_ToMatrix: TypeAlias = _spbase[_SCT] | onp.CanArrayND[_SCT] | Sequence[onp.CanArrayND[_SCT]] | _ToMatrixPy[_SCT]
2431
_ToMatrixPy: TypeAlias = Sequence[_T] | Sequence[Sequence[_T]]
2532

33+
_ToKey1D: TypeAlias = onp.ToJustInt | tuple[onp.ToJustInt]
34+
_ToKey2D: TypeAlias = tuple[onp.ToJustInt, onp.ToJustInt]
35+
36+
_ToKeys1D: TypeAlias = Iterable[_ToKey1D]
37+
_ToKeys2D: TypeAlias = Iterable[_ToKey2D]
38+
2639
###
2740

2841
class _dok_base(_spbase[_SCT, _ShapeT_co], IndexMixin[_SCT, _ShapeT_co], dict[ShapeDOK, _SCT], Generic[_SCT, _ShapeT_co]):
@@ -31,11 +44,9 @@ class _dok_base(_spbase[_SCT, _ShapeT_co], IndexMixin[_SCT, _ShapeT_co], dict[Sh
3144
@property
3245
@override
3346
def format(self, /) -> Literal["dok"]: ...
34-
#
3547
@property
3648
@override
3749
def ndim(self, /) -> Literal[1, 2]: ...
38-
#
3950
@property
4051
@override
4152
def shape(self, /) -> _ShapeT_co: ...
@@ -129,10 +140,20 @@ class _dok_base(_spbase[_SCT, _ShapeT_co], IndexMixin[_SCT, _ShapeT_co], dict[Sh
129140
copy: bool = False,
130141
maxprint: int | None = None,
131142
) -> None: ...
143+
@override
144+
def todok(self, /, copy: bool = False) -> Self: ...
132145

133146
#
134147
@override
135-
def __delitem__(self, key: onp.ToJustInt, /) -> None: ...
148+
def __len__(self, /) -> int: ...
149+
150+
#
151+
@overload
152+
def __delitem__(self: _dok_base[Any, _2D], key: _ToKey2D, /) -> None: ...
153+
@overload
154+
def __delitem__(self: _dok_base[Any, _1D], key: _ToKey1D, /) -> None: ...
155+
@overload
156+
def __delitem__(self, key: _ToKey1D | _ToKey2D, /) -> None: ...
136157

137158
#
138159
@override
@@ -141,27 +162,183 @@ class _dok_base(_spbase[_SCT, _ShapeT_co], IndexMixin[_SCT, _ShapeT_co], dict[Sh
141162
def __ror__(self, other: Never, /) -> Never: ...
142163
@override
143164
def __ior__(self, other: Never, /) -> Never: ... # noqa: PYI034
165+
166+
#
167+
@overload
168+
def count_nonzero(self, /, axis: None = None) -> int: ...
169+
@overload
170+
def count_nonzero(self, /, axis: op.CanIndex) -> onp.Array1D[np.intp]: ...
171+
172+
#
144173
@override
145174
def update(self, /, val: Never) -> Never: ...
146175

147-
# TODO(jorenham)
148-
@override
149-
def get(self, key: onp.ToJustInt | ShapeDOK, /, default: onp.ToComplex = 0.0) -> _SCT: ...
150-
@override
151-
def setdefault(self, key: onp.ToJustInt | ShapeDOK, default: onp.ToComplex | None = None, /) -> _SCT: ...
152-
@classmethod
153-
@override
154-
def fromkeys(cls, iterable: Iterable[ShapeDOK], value: int = 1, /) -> Self: ...
176+
#
177+
@overload
178+
def setdefault(self: _dok_base[Any, _2D], key: _ToKey2D, default: _T, /) -> _SCT | _T: ...
179+
@overload
180+
def setdefault(self: _dok_base[Any, _2D], key: _ToKey2D, default: None = None, /) -> _SCT | None: ...
181+
@overload
182+
def setdefault(self: _dok_base[Any, _1D], key: _ToKey1D, default: _T, /) -> _SCT | _T: ...
183+
@overload
184+
def setdefault(self: _dok_base[Any, _1D], key: _ToKey1D, default: None = None, /) -> _SCT | None: ...
185+
@overload
186+
def setdefault(self, key: _ToKey1D | _ToKey2D, default: _T, /) -> _SCT | _T: ...
187+
@overload
188+
def setdefault(self, key: _ToKey1D | _ToKey2D, default: None = None, /) -> _SCT | None: ...
189+
190+
#
191+
@overload
192+
def get(self: _dok_base[Any, _2D], /, key: _ToKey2D, default: _T) -> _SCT | _T: ...
193+
@overload
194+
def get(self: _dok_base[Any, _2D], /, key: _ToKey2D, default: float = 0.0) -> _SCT | float: ...
195+
@overload
196+
def get(self: _dok_base[Any, _1D], /, key: _ToKey1D, default: _T) -> _SCT | _T: ...
197+
@overload
198+
def get(self: _dok_base[Any, _1D], /, key: _ToKey1D, default: float = 0.0) -> _SCT | float: ...
199+
@overload
200+
def get(self, /, key: _ToKey1D | _ToKey2D, default: _T) -> _SCT | _T: ...
201+
@overload
202+
def get(self, /, key: _ToKey1D | _ToKey2D, default: float = 0.0) -> _SCT | float: ...
155203

156204
#
157205
def conjtransp(self, /) -> Self: ...
158206

159-
class dok_array(_dok_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]): ...
207+
#
208+
@overload
209+
@classmethod
210+
def fromkeys(cls: type[_dok_base[_SCT, _2D]], iterable: _ToKeys2D, v: _SCT, /) -> _dok_base[_SCT, _2D]: ...
211+
@overload
212+
@classmethod
213+
def fromkeys(cls: type[_dok_base[_SCT, _1D]], iterable: _ToKeys1D, v: _SCT, /) -> _dok_base[_SCT, _1D]: ...
214+
@overload
215+
@classmethod
216+
def fromkeys(cls: type[_dok_base[np.bool_, _2D]], iterable: _ToKeys2D, v: onp.ToBool, /) -> _dok_base[np.bool_, _2D]: ...
217+
@overload
218+
@classmethod
219+
def fromkeys(cls: type[_dok_base[np.bool_, _1D]], iterable: _ToKeys1D, v: onp.ToBool, /) -> _dok_base[np.bool_, _1D]: ...
220+
@overload
221+
@classmethod
222+
def fromkeys(cls: type[_dok_base[np.int_, _2D]], iterable: _ToKeys2D, v: opt.JustInt = 1, /) -> _dok_base[np.int_, _2D]: ...
223+
@overload
224+
@classmethod
225+
def fromkeys(cls: type[_dok_base[np.int_, _1D]], iterable: _ToKeys1D, v: opt.JustInt = 1, /) -> _dok_base[np.int_, _1D]: ...
226+
@overload
227+
@classmethod
228+
def fromkeys(
229+
cls: type[_dok_base[np.float64, _2D]],
230+
iterable: _ToKeys2D,
231+
v: opt.JustFloat,
232+
/,
233+
) -> _dok_base[np.float64, _2D]: ...
234+
@overload
235+
@classmethod
236+
def fromkeys(
237+
cls: type[_dok_base[np.float64, _1D]],
238+
iterable: _ToKeys1D,
239+
v: opt.JustFloat,
240+
/,
241+
) -> _dok_base[np.float64, _1D]: ...
242+
@overload
243+
@classmethod
244+
def fromkeys(
245+
cls: type[_dok_base[np.complex128, _2D]],
246+
iterable: _ToKeys2D,
247+
v: opt.JustComplex,
248+
/,
249+
) -> _dok_base[np.complex128, _2D]: ...
250+
@overload
251+
@classmethod
252+
def fromkeys(
253+
cls: type[_dok_base[np.complex128, _1D]],
254+
iterable: _ToKeys1D,
255+
v: opt.JustComplex,
256+
/,
257+
) -> _dok_base[np.complex128, _1D]: ...
160258

161-
class dok_matrix(_dok_base[_SCT, tuple[int, int]], spmatrix[_SCT], Generic[_SCT]):
259+
#
260+
class dok_array(_dok_base[_SCT, _ShapeT_co], sparray, Generic[_SCT, _ShapeT_co]):
261+
# NOTE: This horrible code duplication is required due to the lack of higher-kinded typing (HKT) support.
262+
# https://github.com/python/typing/issues/548
263+
@overload
264+
@classmethod
265+
def fromkeys(cls: type[dok_array[_SCT, _2D]], iterable: _ToKeys2D, v: _SCT, /) -> dok_array[_SCT, _2D]: ...
266+
@overload
267+
@classmethod
268+
def fromkeys(cls: type[dok_array[_SCT, _1D]], iterable: _ToKeys1D, v: _SCT, /) -> dok_array[_SCT, _1D]: ...
269+
@overload
270+
@classmethod
271+
def fromkeys(cls: type[dok_array[np.bool_, _2D]], iterable: _ToKeys2D, v: onp.ToBool, /) -> dok_array[np.bool_, _2D]: ...
272+
@overload
273+
@classmethod
274+
def fromkeys(cls: type[dok_array[np.bool_, _1D]], iterable: _ToKeys1D, v: onp.ToBool, /) -> dok_array[np.bool_, _1D]: ...
275+
@overload
276+
@classmethod
277+
def fromkeys(cls: type[dok_array[np.int_, _2D]], iterable: _ToKeys2D, v: opt.JustInt = 1, /) -> dok_array[np.int_, _2D]: ...
278+
@overload
279+
@classmethod
280+
def fromkeys(cls: type[dok_array[np.int_, _1D]], iterable: _ToKeys1D, v: opt.JustInt = 1, /) -> dok_array[np.int_, _1D]: ...
281+
@overload
282+
@classmethod
283+
def fromkeys(
284+
cls: type[dok_array[np.float64, _2D]],
285+
iterable: _ToKeys2D,
286+
v: opt.JustFloat,
287+
/,
288+
) -> dok_array[np.float64, _2D]: ...
289+
@overload
290+
@classmethod
291+
def fromkeys(
292+
cls: type[dok_array[np.float64, _1D]],
293+
iterable: _ToKeys1D,
294+
v: opt.JustFloat,
295+
/,
296+
) -> dok_array[np.float64, _1D]: ...
297+
@overload
298+
@classmethod
299+
def fromkeys(
300+
cls: type[dok_array[np.complex128, _2D]],
301+
iterable: _ToKeys2D,
302+
v: opt.JustComplex,
303+
/,
304+
) -> dok_array[np.complex128, _2D]: ...
305+
@overload
306+
@classmethod
307+
def fromkeys(
308+
cls: type[dok_array[np.complex128, _1D]],
309+
iterable: _ToKeys1D,
310+
v: opt.JustComplex,
311+
/,
312+
) -> dok_array[np.complex128, _1D]: ...
313+
314+
#
315+
class dok_matrix(_dok_base[_SCT, _2D], spmatrix[_SCT], Generic[_SCT]):
162316
@override
163-
def get(self, key: tuple[onp.ToJustInt, onp.ToJustInt], /, default: onp.ToComplex = 0.0) -> _SCT: ...
317+
def get(self, /, key: _ToKey2D, default: onp.ToComplex = 0.0) -> _SCT: ...
164318
@override
165-
def setdefault(self, key: tuple[onp.ToJustInt, onp.ToJustInt], default: onp.ToComplex | None = None, /) -> _SCT: ...
319+
def setdefault(self, key: _ToKey2D, default: onp.ToComplex | None = None, /) -> _SCT: ...
320+
321+
#
322+
@overload
323+
@classmethod
324+
def fromkeys(cls, iterable: _ToKeys2D, v: _SCT, /) -> Self: ...
325+
@overload
326+
@classmethod
327+
def fromkeys(cls: type[dok_matrix[np.bool_]], iterable: _ToKeys2D, v: onp.ToBool, /) -> dok_matrix[np.bool_]: ...
328+
@overload
329+
@classmethod
330+
def fromkeys(cls: type[dok_matrix[np.int_]], iterable: _ToKeys2D, v: opt.JustInt = 1, /) -> dok_matrix[np.int_]: ...
331+
@overload
332+
@classmethod
333+
def fromkeys(cls: type[dok_matrix[np.float64]], iterable: _ToKeys2D, v: opt.JustFloat, /) -> dok_matrix[np.float64]: ...
334+
@overload
335+
@classmethod
336+
def fromkeys(
337+
cls: type[dok_matrix[np.complex128]],
338+
iterable: _ToKeys2D,
339+
v: opt.JustComplex,
340+
/,
341+
) -> dok_matrix[np.complex128]: ...
166342

343+
#
167344
def isspmatrix_dok(x: object) -> TypeIs[dok_matrix]: ...

0 commit comments

Comments
 (0)