Skip to content

Commit 404f5c1

Browse files
authored
🐛 sparse.linalg: Fix factorize return type annotations (#679)
2 parents 1e6c429 + f9a122f commit 404f5c1

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

scipy-stubs/sparse/linalg/_dsolve/linsolve.pyi

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Mapping
2-
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, overload, type_check_only
3-
from typing_extensions import deprecated
2+
from typing import Any, Literal, Protocol, TypeAlias, overload, type_check_only
3+
from typing_extensions import TypeVar, deprecated
44

55
import numpy as np
66
import optype.numpy as onp
@@ -24,12 +24,10 @@ __all__ = [
2424
]
2525

2626
_SparseT = TypeVar("_SparseT", bound=_spbase)
27+
_NumberT_contra = TypeVar("_NumberT_contra", bound=npc.number, contravariant=True)
28+
_InexactT_co = TypeVar("_InexactT_co", bound=np.float32 | np.float64 | np.complex64 | np.complex128, covariant=True)
2729

2830
_PermcSpec: TypeAlias = Literal["COLAMD", "NATURAL", "MMD_ATA", "MMD_AT_PLUS_A"]
29-
_Float1D: TypeAlias = onp.Array1D[np.float64]
30-
_Float2D: TypeAlias = onp.Array2D[np.float64]
31-
_Complex1D: TypeAlias = onp.Array1D[np.complex128]
32-
_Complex2D: TypeAlias = onp.Array2D[np.complex128]
3331

3432
_ToF32Mat: TypeAlias = _spbase[np.float32, tuple[int, int]] | onp.CanArray[tuple[Any, ...], np.dtype[np.float32]]
3533
_ToF64Mat: TypeAlias = _spbase[np.float64 | npc.integer, tuple[int, int]] | onp.ToInt2D | onp.ToJustFloat64_2D
@@ -42,31 +40,35 @@ _ToComplexMat: TypeAlias = _spbase[npc.complexfloating, tuple[int, int]] | onp.T
4240
_ToInexactMat: TypeAlias = _spbase[Any, tuple[int, int]] | onp.ToComplex128_2D
4341
_ToInexactMatStrict: TypeAlias = _spbase[Any, tuple[int, int]] | onp.ToComplex128Strict2D
4442

45-
# TODO(jorenham): make generic (because safe casting rules apply, i.e. the current annotations are incorrect)
46-
# https://github.com/scipy/scipy-stubs/issues/677
43+
_AsF32: TypeAlias = npc.integer8 | npc.number16 | np.int32 | np.float16 | np.float32
44+
_AsF64: TypeAlias = npc.integer | np.float16 | np.float32 | np.float64
45+
_AsC64: TypeAlias = npc.integer8 | npc.integer16 | np.int32 | np.float16 | npc.inexact32
46+
_AsC128: TypeAlias = npc.integer | np.float16 | npc.inexact32 | npc.inexact64
47+
4748
@type_check_only
48-
class _Solve(Protocol):
49-
@overload
50-
def __call__(self, b: onp.Array1D[npc.integer | npc.floating], /) -> _Float1D: ...
51-
@overload
52-
def __call__(self, b: onp.Array1D[npc.complexfloating], /) -> _Complex1D: ...
49+
class _SuperLU_solve(Protocol[_NumberT_contra, _InexactT_co]):
5350
@overload
54-
def __call__(self, b: onp.Array2D[npc.integer | npc.floating], /) -> _Float2D: ...
51+
def __call__(self, rhs: onp.Array1D[_NumberT_contra | np.bool_]) -> onp.Array1D[_InexactT_co]: ...
5552
@overload
56-
def __call__(self, b: onp.Array2D[npc.complexfloating], /) -> _Complex2D: ...
53+
def __call__(self, rhs: onp.Array2D[_NumberT_contra | np.bool_]) -> onp.Array2D[_InexactT_co]: ...
5754
@overload
58-
def __call__(self, b: onp.ArrayND[npc.integer | npc.floating], /) -> _Float1D | _Float2D: ...
59-
@overload
60-
def __call__(self, b: onp.ArrayND[npc.complexfloating], /) -> _Complex1D | _Complex2D: ...
61-
@overload
62-
def __call__(self, b: onp.ArrayND[npc.number], /) -> _Float1D | _Complex1D | _Float2D | _Complex2D: ...
55+
def __call__(self, rhs: onp.ArrayND[_NumberT_contra | np.bool_]) -> onp.ArrayND[_InexactT_co]: ...
6356

6457
###
6558

6659
class MatrixRankWarning(UserWarning): ...
6760

68-
def use_solver(*, useUmfpack: bool = ..., assumeSortedIndices: bool = ...) -> None: ...
69-
def factorized(A: _ToInexactMat) -> _Solve: ...
61+
def use_solver(*, useUmfpack: bool = True, assumeSortedIndices: bool = False) -> None: ...
62+
63+
# NOTE: the mypy ignores work around a mypy bug in overload overlap checking
64+
@overload
65+
def factorized(A: _ToF64Mat) -> _SuperLU_solve[_AsF64, np.float64]: ... # type: ignore[overload-overlap]
66+
@overload
67+
def factorized(A: _ToC128Mat) -> _SuperLU_solve[_AsC128, np.complex128]: ... # type: ignore[overload-overlap]
68+
@overload
69+
def factorized(A: _ToF32Mat) -> _SuperLU_solve[_AsF32, np.float32]: ...
70+
@overload
71+
def factorized(A: _ToC64Mat) -> _SuperLU_solve[_AsC64, np.complex64]: ...
7072

7173
#
7274
@overload # 2d float, sparse 2d

0 commit comments

Comments
 (0)