Skip to content

Commit 4d5916e

Browse files
authored
✨ improve the shape- & scalar-type overloads of various linalg functions (#252)
1 parent fe4a14f commit 4d5916e

File tree

8 files changed

+1891
-1724
lines changed

8 files changed

+1891
-1724
lines changed

scipy-stubs/linalg/_matfuncs.pyi

Lines changed: 105 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ from collections.abc import Callable
22
from typing import Any, Literal, TypeAlias, overload
33

44
import numpy as np
5-
import numpy.typing as npt
65
import optype.numpy as onp
6+
import optype.typing as opt
77
from ._expm_frechet import expm_cond, expm_frechet
88
from ._matfuncs_sqrtm import sqrtm
99

@@ -25,35 +25,107 @@ __all__ = [
2525
"tanm",
2626
]
2727

28-
_Array_fc_2d: TypeAlias = onp.Array2D[np.inexact[Any]]
29-
_Array_fc_nd: TypeAlias = onp.ArrayND[np.inexact[Any]]
30-
31-
def fractional_matrix_power(A: npt.ArrayLike, t: float) -> _Array_fc_2d: ...
32-
@overload
33-
def logm(A: npt.ArrayLike, disp: Literal[True] = True) -> _Array_fc_2d: ...
34-
@overload
35-
def logm(A: npt.ArrayLike, disp: Literal[False]) -> tuple[_Array_fc_2d, float | np.float64]: ...
36-
def expm(A: npt.ArrayLike) -> onp.ArrayND[np.inexact[Any]]: ...
37-
def cosm(A: npt.ArrayLike) -> _Array_fc_2d: ...
38-
def sinm(A: npt.ArrayLike) -> _Array_fc_2d: ...
39-
def tanm(A: npt.ArrayLike) -> _Array_fc_2d: ...
40-
def coshm(A: npt.ArrayLike) -> _Array_fc_2d: ...
41-
def sinhm(A: npt.ArrayLike) -> _Array_fc_2d: ...
42-
def tanhm(A: npt.ArrayLike) -> _Array_fc_2d: ...
43-
@overload
44-
def funm(
45-
A: npt.ArrayLike,
46-
func: Callable[[_Array_fc_nd], _Array_fc_nd],
47-
disp: Literal[True, 1] = True,
48-
) -> _Array_fc_2d: ...
49-
@overload
50-
def funm(
51-
A: npt.ArrayLike,
52-
func: Callable[[_Array_fc_nd], _Array_fc_nd],
53-
disp: Literal[False, 0],
54-
) -> tuple[_Array_fc_2d, float | np.float64]: ...
55-
@overload
56-
def signm(A: npt.ArrayLike, disp: Literal[True] = True) -> _Array_fc_2d: ...
57-
@overload
58-
def signm(A: npt.ArrayLike, disp: Literal[False]) -> tuple[_Array_fc_2d, float | np.float64]: ...
59-
def khatri_rao(a: npt.ArrayLike, b: npt.ArrayLike) -> _Array_fc_2d: ...
28+
_ToPosInt: TypeAlias = np.unsignedinteger[Any] | Literal[0, 1, 2, 4, 5, 6, 7, 8]
29+
30+
_Int2D: TypeAlias = onp.Array2D[np.integer[Any]]
31+
_Complex2D: TypeAlias = onp.Array2D[np.complexfloating[Any, Any]]
32+
_Real2D: TypeAlias = onp.Array2D[np.floating[Any] | np.integer[Any]]
33+
_Numeric2D: TypeAlias = onp.Array2D[np.number[Any]]
34+
_Float2D: TypeAlias = onp.Array2D[np.floating[Any]]
35+
_Inexact2D: TypeAlias = onp.Array2D[np.inexact[Any]]
36+
37+
_FloatND: TypeAlias = onp.ArrayND[np.floating[Any]]
38+
_InexactND: TypeAlias = onp.ArrayND[np.inexact[Any]]
39+
40+
_Falsy: TypeAlias = Literal[False, 0]
41+
_Truthy: TypeAlias = Literal[True, 1]
42+
43+
_FloatFunc: TypeAlias = Callable[[onp.Array1D[np.float64]], onp.ToFloat1D]
44+
_ComplexFunc: TypeAlias = Callable[[onp.Array1D[np.complex128]], onp.ToComplex1D]
45+
46+
###
47+
48+
@overload # int, positive int
49+
def fractional_matrix_power(A: onp.ToInt2D, t: _ToPosInt) -> _Int2D: ... # pyright: ignore[reportOverlappingOverload] # pyright<1.390 bug
50+
@overload # real, int
51+
def fractional_matrix_power(A: onp.ToFloat2D, t: onp.ToInt) -> _Real2D: ...
52+
@overload # complex, int
53+
def fractional_matrix_power(A: onp.ToComplex2D, t: onp.ToInt) -> _Numeric2D: ...
54+
@overload # complex, float
55+
def fractional_matrix_power(A: onp.ToComplex2D, t: opt.Just[float] | np.floating[Any]) -> _Complex2D: ...
56+
57+
# NOTE: return dtype depends on the sign of the values
58+
@overload # disp: True = ...
59+
def logm(A: onp.ToComplex2D, disp: _Truthy = True) -> _Inexact2D: ...
60+
@overload # disp: False
61+
def logm(A: onp.ToComplex2D, disp: _Falsy) -> tuple[_Inexact2D, float]: ...
62+
63+
#
64+
@overload # real
65+
def expm(A: onp.ToFloatND) -> _FloatND: ...
66+
@overload # complex
67+
def expm(A: onp.ToComplexND) -> _InexactND: ...
68+
69+
#
70+
@overload # real
71+
def cosm(A: onp.ToFloat2D) -> _Float2D: ...
72+
@overload # complex
73+
def cosm(A: onp.ToComplex2D) -> _Inexact2D: ...
74+
75+
#
76+
@overload # real
77+
def sinm(A: onp.ToFloat2D) -> _Float2D: ...
78+
@overload # complex
79+
def sinm(A: onp.ToComplex2D) -> _Inexact2D: ...
80+
81+
#
82+
@overload # real
83+
def tanm(A: onp.ToFloat2D) -> _Float2D: ...
84+
@overload # complex
85+
def tanm(A: onp.ToComplex2D) -> _Inexact2D: ...
86+
87+
#
88+
@overload # real
89+
def coshm(A: onp.ToFloat2D) -> _Float2D: ...
90+
@overload # complex
91+
def coshm(A: onp.ToComplex2D) -> _Inexact2D: ...
92+
93+
#
94+
@overload # real
95+
def sinhm(A: onp.ToFloat2D) -> _Float2D: ...
96+
@overload # complex
97+
def sinhm(A: onp.ToComplex2D) -> _Inexact2D: ...
98+
99+
#
100+
@overload # real
101+
def tanhm(A: onp.ToFloat2D) -> _Float2D: ...
102+
@overload # complex
103+
def tanhm(A: onp.ToComplex2D) -> _Inexact2D: ...
104+
105+
#
106+
@overload # real, disp: True = ...
107+
def funm(A: onp.ToFloat2D, func: _FloatFunc, disp: _Truthy = True) -> _Float2D: ...
108+
@overload # real, disp: False
109+
def funm(A: onp.ToFloat2D, func: _FloatFunc, disp: _Falsy) -> _Complex2D: ...
110+
@overload # complex, disp: True = ...
111+
def funm(A: onp.ToComplex2D, func: _ComplexFunc, disp: _Truthy = True) -> _Complex2D: ...
112+
@overload # complex, disp: False
113+
def funm(A: onp.ToComplex2D, func: _ComplexFunc, disp: _Falsy) -> tuple[_Complex2D, np.float64]: ...
114+
115+
#
116+
@overload # real, disp: True = ...
117+
def signm(A: onp.ToFloat2D, disp: _Truthy = True) -> _Float2D: ...
118+
@overload # real, disp: False
119+
def signm(A: onp.ToFloat2D, disp: _Falsy) -> tuple[_Float2D, np.float64]: ...
120+
@overload # complex, disp: True = ...
121+
def signm(A: onp.ToComplex2D, disp: _Truthy = True) -> _Inexact2D: ...
122+
@overload # complex, disp: False
123+
def signm(A: onp.ToComplex2D, disp: _Falsy) -> tuple[_Inexact2D, np.float64]: ...
124+
125+
#
126+
@overload # int
127+
def khatri_rao(a: onp.ToInt2D, b: onp.ToInt2D) -> _Int2D: ...
128+
@overload # real
129+
def khatri_rao(a: onp.ToFloat2D, b: onp.ToFloat2D) -> _Real2D: ...
130+
@overload # complex
131+
def khatri_rao(a: onp.ToComplex2D, b: onp.ToComplex2D) -> _Numeric2D: ...
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
from typing import TypeAlias
2+
3+
import numpy as np
14
import optype.numpy as onp
25

36
__all__ = ["pade_UV_calc", "pick_pade_structure"]
47

5-
def pick_pade_structure(Am: onp.ArrayND) -> tuple[int, int]: ...
6-
def pade_UV_calc(Am: onp.ArrayND, n: int, m: int) -> None: ...
8+
_Inexact: TypeAlias = np.float32 | np.float64 | np.complex64 | np.complex128
9+
10+
# `Am` must have a shape like `(5, n, n)`.
11+
def pick_pade_structure(Am: onp.ArrayND[_Inexact]) -> tuple[int, int]: ...
12+
def pade_UV_calc(Am: onp.ArrayND[_Inexact], n: int, m: int) -> None: ...
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from typing import Any, Literal, TypeAlias, overload
22

33
import numpy as np
4-
import numpy.typing as npt
54
import optype.numpy as onp
65

76
__all__ = ["sqrtm"]
87

9-
_Array_fc_2d: TypeAlias = onp.Array2D[np.inexact[Any]]
8+
_Inexact2D: TypeAlias = onp.Array2D[np.floating[Any]] | onp.Array2D[np.complexfloating[Any, Any]]
109

11-
class SqrtmError(np.linalg.LinAlgError): ...
10+
_Falsy: TypeAlias = Literal[False, 0]
11+
_Truthy: TypeAlias = Literal[True, 1]
1212

13+
###
14+
15+
class SqrtmError(np.linalg.LinAlgError): ... # undocumented
16+
17+
# NOTE: The output dtype (floating or complex) depends on the sign of the values, so this is the best we can do.
1318
@overload
14-
def sqrtm(A: npt.ArrayLike, disp: Literal[True] = True, blocksize: int = 64) -> _Array_fc_2d: ...
19+
def sqrtm(A: onp.ToComplex2D, disp: _Truthy = True, blocksize: onp.ToJustInt = 64) -> _Inexact2D: ...
1520
@overload
16-
def sqrtm(A: npt.ArrayLike, disp: Literal[False], blocksize: int = 64) -> tuple[_Array_fc_2d, float | np.float64]: ...
21+
def sqrtm(A: onp.ToComplex2D, disp: _Falsy, blocksize: onp.ToJustInt = 64) -> tuple[_Inexact2D, np.floating[Any]]: ...

scipy-stubs/linalg/_procrustes.pyi

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
1-
from typing import Any
1+
from typing import TypeAlias, overload
22

33
import numpy as np
4-
import numpy.typing as npt
4+
import optype as op
55
import optype.numpy as onp
66

77
__all__ = ["orthogonal_procrustes"]
88

9+
_Float: TypeAlias = np.float32 | np.float64
10+
_Complex: TypeAlias = np.complex64 | np.complex128
11+
12+
###
13+
14+
@overload
15+
def orthogonal_procrustes(
16+
A: onp.ToFloat2D,
17+
B: onp.ToFloat2D,
18+
check_finite: op.CanBool = True,
19+
) -> tuple[onp.Array2D[_Float], _Float]: ...
20+
@overload
921
def orthogonal_procrustes(
10-
A: npt.ArrayLike,
11-
B: npt.ArrayLike,
12-
check_finite: bool = True,
13-
) -> tuple[onp.Array2D[np.inexact[Any]], float | np.float64]: ...
22+
A: onp.ToComplex2D,
23+
B: onp.ToComplex2D,
24+
check_finite: onp.ToBool = True,
25+
) -> tuple[onp.Array2D[_Float | _Complex], _Float]: ...

scipy-stubs/linalg/_sketches.pyi

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,38 @@
1+
from collections.abc import Sequence
2+
from typing import Any, TypeAlias, TypeVar, overload
3+
14
import numpy as np
2-
import numpy.typing as npt
35
import optype.numpy as onp
6+
import optype.typing as opt
47
import scipy._typing as spt
58
from scipy.sparse import csc_matrix
69

710
__all__ = ["clarkson_woodruff_transform"]
811

9-
def cwt_matrix(n_rows: int, n_columns: int, seed: spt.Seed | None = None) -> csc_matrix: ...
12+
_ST = TypeVar("_ST", bound=np.generic)
13+
_VT = TypeVar("_VT")
14+
_ToJust2D: TypeAlias = onp.CanArrayND[_ST] | Sequence[onp.CanArrayND[_ST]] | Sequence[Sequence[opt.Just[_VT] | _ST]]
15+
16+
###
17+
18+
def cwt_matrix(n_rows: onp.ToInt, n_columns: onp.ToInt, seed: spt.Seed | None = None) -> csc_matrix: ...
19+
20+
#
21+
@overload
22+
def clarkson_woodruff_transform(
23+
input_matrix: onp.ToInt2D,
24+
sketch_size: onp.ToInt,
25+
seed: spt.Seed | None = None,
26+
) -> onp.Array2D[np.int_]: ...
27+
@overload
28+
def clarkson_woodruff_transform(
29+
input_matrix: _ToJust2D[np.floating[Any], float],
30+
sketch_size: onp.ToInt,
31+
seed: spt.Seed | None = None,
32+
) -> onp.Array2D[np.float64 | np.longdouble]: ...
33+
@overload
1034
def clarkson_woodruff_transform(
11-
input_matrix: npt.ArrayLike,
12-
sketch_size: int,
35+
input_matrix: _ToJust2D[np.complexfloating[Any, Any], complex],
36+
sketch_size: onp.ToInt,
1337
seed: spt.Seed | None = None,
14-
) -> onp.ArrayND[np.float64]: ...
38+
) -> onp.Array2D[np.complex64 | np.clongdouble]: ...

scipy-stubs/linalg/_solvers.pyi

Lines changed: 75 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Any, Literal
1+
from typing import Final, Literal, TypeAlias, overload
22

33
import numpy as np
4-
import numpy.typing as npt
4+
import optype as op
55
import optype.numpy as onp
66

77
__all__ = [
@@ -13,31 +13,81 @@ __all__ = [
1313
"solve_sylvester",
1414
]
1515

16-
def solve_sylvester(a: npt.ArrayLike, b: npt.ArrayLike, q: npt.ArrayLike) -> onp.Array2D[np.inexact[Any]]: ...
17-
def solve_continuous_lyapunov(a: npt.ArrayLike, q: npt.ArrayLike) -> onp.Array2D[np.inexact[Any]]: ...
16+
_Float: TypeAlias = np.float32 | np.float64
17+
_Complex: TypeAlias = np.complex64 | np.complex128
1818

19-
solve_lyapunov = solve_continuous_lyapunov
19+
_DiscreteMethod: TypeAlias = Literal["direct", "bilinear"]
2020

21+
###
22+
23+
@overload # real
24+
def solve_sylvester(a: onp.ToFloat2D, b: onp.ToFloat2D, q: onp.ToFloat2D) -> onp.Array2D[_Float]: ...
25+
@overload # complex
26+
def solve_sylvester(a: onp.ToComplex2D, b: onp.ToComplex2D, q: onp.ToComplex2D) -> onp.Array2D[_Float | _Complex]: ...
27+
28+
#
29+
@overload # real
30+
def solve_continuous_lyapunov(a: onp.ToFloat2D, q: onp.ToFloat2D) -> onp.Array2D[_Float]: ...
31+
@overload # complex
32+
def solve_continuous_lyapunov(a: onp.ToComplex2D, q: onp.ToComplex2D) -> onp.Array2D[_Float | _Complex]: ...
33+
34+
#
35+
solve_lyapunov: Final = solve_continuous_lyapunov
36+
37+
#
38+
@overload # real
39+
def solve_discrete_lyapunov(
40+
a: onp.ToFloat2D,
41+
q: onp.ToFloat2D,
42+
method: _DiscreteMethod | None = None,
43+
) -> onp.Array2D[_Float]: ...
44+
@overload # complex
2145
def solve_discrete_lyapunov(
22-
a: npt.ArrayLike,
23-
q: npt.ArrayLike,
24-
method: Literal["direct", "bilinear"] | None = None,
25-
) -> onp.Array2D[np.inexact[Any]]: ...
46+
a: onp.ToComplex2D,
47+
q: onp.ToComplex2D,
48+
method: _DiscreteMethod | None = None,
49+
) -> onp.Array2D[_Float | _Complex]: ...
50+
51+
#
52+
@overload # real
53+
def solve_continuous_are(
54+
a: onp.ToFloat2D,
55+
b: onp.ToFloat2D,
56+
q: onp.ToFloat2D,
57+
r: onp.ToFloat2D,
58+
e: onp.ToFloat2D | None = None,
59+
s: onp.ToFloat2D | None = None,
60+
balanced: op.CanBool = True,
61+
) -> onp.Array2D[_Float]: ...
62+
@overload # complex
2663
def solve_continuous_are(
27-
a: npt.ArrayLike,
28-
b: npt.ArrayLike,
29-
q: npt.ArrayLike,
30-
r: npt.ArrayLike,
31-
e: npt.ArrayLike | None = None,
32-
s: npt.ArrayLike | None = None,
33-
balanced: bool = True,
34-
) -> onp.Array2D[np.inexact[Any]]: ...
64+
a: onp.ToComplex2D,
65+
b: onp.ToComplex2D,
66+
q: onp.ToComplex2D,
67+
r: onp.ToComplex2D,
68+
e: onp.ToComplex2D | None = None,
69+
s: onp.ToComplex2D | None = None,
70+
balanced: op.CanBool = True,
71+
) -> onp.Array2D[_Float | _Complex]: ...
72+
73+
#
74+
@overload # real
75+
def solve_discrete_are(
76+
a: onp.ToFloat2D,
77+
b: onp.ToFloat2D,
78+
q: onp.ToFloat2D,
79+
r: onp.ToFloat2D,
80+
e: onp.ToFloat2D | None = None,
81+
s: onp.ToFloat2D | None = None,
82+
balanced: op.CanBool = True,
83+
) -> onp.Array2D[_Float]: ...
84+
@overload # complex
3585
def solve_discrete_are(
36-
a: npt.ArrayLike,
37-
b: npt.ArrayLike,
38-
q: npt.ArrayLike,
39-
r: npt.ArrayLike,
40-
e: npt.ArrayLike | None = None,
41-
s: npt.ArrayLike | None = None,
42-
balanced: bool = True,
43-
) -> onp.Array2D[np.inexact[Any]]: ...
86+
a: onp.ToComplex2D,
87+
b: onp.ToComplex2D,
88+
q: onp.ToComplex2D,
89+
r: onp.ToComplex2D,
90+
e: onp.ToComplex2D | None = None,
91+
s: onp.ToComplex2D | None = None,
92+
balanced: op.CanBool = True,
93+
) -> onp.Array2D[_Float | _Complex]: ...

0 commit comments

Comments
 (0)