Skip to content

Commit 241509b

Browse files
committed
linalg: improved inv annotations
1 parent 373827a commit 241509b

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

scipy-stubs/linalg/_basic.pyi

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mypy: disable-error-code=overload-overlap
22

3+
from collections.abc import Sequence
34
from typing import Final, Literal, TypeAlias, TypeVar, overload
45

56
import numpy as np
@@ -23,20 +24,20 @@ __all__ = [
2324
"solveh_banded",
2425
]
2526

27+
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
2628
_T = TypeVar("_T")
29+
2730
_Tuple2: TypeAlias = tuple[_T, _T]
2831
_COrCR: TypeAlias = _T | _Tuple2[_T]
2932

3033
_Float: TypeAlias = npc.floating
3134
_Float0D: TypeAlias = onp.Array0D[_Float]
3235
_Float1D: TypeAlias = onp.Array1D[_Float]
33-
_Float2D: TypeAlias = onp.Array2D[_Float]
3436
_FloatND: TypeAlias = onp.ArrayND[_Float]
3537

3638
_Inexact: TypeAlias = npc.inexact
3739
_Inexact0D: TypeAlias = onp.Array0D[_Inexact]
3840
_Inexact1D: TypeAlias = onp.Array1D[_Inexact]
39-
_Inexact2D: TypeAlias = onp.Array2D[_Inexact]
4041
_InexactND: TypeAlias = onp.ArrayND[_Inexact]
4142

4243
_InputFloat: TypeAlias = onp.ToArrayND[float, np.float64 | np.longdouble | npc.integer | np.bool_]
@@ -992,15 +993,48 @@ def solve_circulant(
992993
outaxis: int = 0,
993994
) -> onp.ArrayND[npc.inexact]: ...
994995

995-
# TODO(jorenham): improve this
996-
@overload # floating 2d
997-
def inv(a: onp.ToFloatStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Float2D: ...
998-
@overload # floating
999-
def inv(a: onp.ToFloatND, overwrite_a: bool = False, check_finite: bool = True) -> _FloatND: ...
1000-
@overload # complexfloating 2d
1001-
def inv(a: onp.ToComplexStrict2D, overwrite_a: bool = False, check_finite: bool = True) -> _Inexact2D: ...
1002-
@overload # complexfloating
1003-
def inv(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> _InexactND: ...
996+
#
997+
998+
@overload # 2d bool sequence
999+
def inv(a: Sequence[Sequence[bool]], overwrite_a: bool = False, check_finite: bool = True) -> onp.Array2D[np.float32]: ...
1000+
@overload # Nd bool sequence
1001+
def inv(a: Sequence[onp.SequenceND[bool]], overwrite_a: bool = False, check_finite: bool = True) -> onp.ArrayND[np.float32]: ...
1002+
@overload # 2d float or int sequence
1003+
def inv(
1004+
a: Sequence[Sequence[op.JustFloat | op.JustInt]], overwrite_a: bool = False, check_finite: bool = True
1005+
) -> onp.Array2D[np.float64]: ...
1006+
@overload # Nd float or int sequence
1007+
def inv(
1008+
a: Sequence[onp.SequenceND[op.JustFloat | op.JustInt]], overwrite_a: bool = False, check_finite: bool = True
1009+
) -> onp.ArrayND[np.float64]: ...
1010+
@overload # 2d complex sequence
1011+
def inv(
1012+
a: Sequence[Sequence[op.JustComplex]], overwrite_a: bool = False, check_finite: bool = True
1013+
) -> onp.Array2D[np.complex128]: ...
1014+
@overload # Nd complex sequence
1015+
def inv(
1016+
a: Sequence[onp.SequenceND[op.JustComplex]], overwrite_a: bool = False, check_finite: bool = True
1017+
) -> onp.ArrayND[np.complex128]: ...
1018+
@overload # generic shape, as float32
1019+
def inv(
1020+
a: onp.CanArrayND[np.float32 | npc.number16 | npc.integer8 | np.bool_, _ShapeT],
1021+
overwrite_a: bool = False,
1022+
check_finite: bool = True,
1023+
) -> onp.ArrayND[np.float32, _ShapeT]: ...
1024+
@overload # generic shape, as float64
1025+
def inv(
1026+
a: onp.CanArrayND[np.float64 | np.longdouble | npc.integer64 | npc.integer32, _ShapeT],
1027+
overwrite_a: bool = False,
1028+
check_finite: bool = True,
1029+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
1030+
@overload # generic shape, as complex64
1031+
def inv(
1032+
a: onp.CanArrayND[np.complex64, _ShapeT], overwrite_a: bool = False, check_finite: bool = True
1033+
) -> onp.ArrayND[np.complex64, _ShapeT]: ...
1034+
@overload # generic shape, as complex128
1035+
def inv(
1036+
a: onp.CanArrayND[np.complex128 | np.clongdouble, _ShapeT], overwrite_a: bool = False, check_finite: bool = True
1037+
) -> onp.ArrayND[np.complex128, _ShapeT]: ...
10041038

10051039
# TODO(jorenham): improve this
10061040
@overload # floating 2d

0 commit comments

Comments
 (0)