Skip to content

Commit dbed9be

Browse files
authored
✨ improve linalg._special_matrices with shape and dtype overloads (#246)
1 parent 8e9ecde commit dbed9be

File tree

1 file changed

+120
-14
lines changed

1 file changed

+120
-14
lines changed

scipy-stubs/linalg/_special_matrices.pyi

Lines changed: 120 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from typing import Literal, TypeAlias, overload
1+
from collections.abc import Sequence
2+
from typing import Any, Literal, TypeAlias, overload
23
from typing_extensions import TypeVar
34

45
import numpy as np
56
import numpy.typing as npt
67
import optype.numpy as onp
8+
import optype.typing as opt
79
from scipy._typing import CorrelateMode
810

911
__all__ = [
@@ -26,43 +28,147 @@ __all__ = [
2628
"toeplitz",
2729
]
2830

29-
_SCT = TypeVar("_SCT", bound=np.generic, default=np.generic)
31+
_SCT = TypeVar("_SCT", bound=np.generic, default=np.number[Any] | np.bool_ | np.object_)
3032

3133
_Matrix: TypeAlias = onp.Array2D[_SCT]
3234
_Kind: TypeAlias = Literal["symmetric", "upper", "lower"]
3335

3436
###
3537

36-
# TODO(jorenham): transparent dtypes
37-
def toeplitz(c: npt.ArrayLike, r: npt.ArrayLike | None = None) -> _Matrix: ...
38-
def circulant(c: npt.ArrayLike) -> _Matrix: ...
39-
def hankel(c: npt.ArrayLike, r: npt.ArrayLike | None = None) -> _Matrix: ...
40-
def hadamard(n: onp.ToInt, dtype: npt.DTypeLike = ...) -> _Matrix: ...
41-
def leslie(f: npt.ArrayLike, s: npt.ArrayLike) -> _Matrix: ...
42-
def kron(a: npt.ArrayLike, b: npt.ArrayLike) -> _Matrix: ...
43-
def block_diag(*arrs: npt.ArrayLike) -> _Matrix: ...
44-
def companion(a: npt.ArrayLike) -> _Matrix: ...
38+
#
39+
@overload
40+
def toeplitz(c: Sequence[opt.JustInt], r: Sequence[opt.JustInt] | None = None) -> _Matrix[np.int_]: ...
41+
@overload
42+
def toeplitz(c: Sequence[opt.Just[float]], r: Sequence[opt.Just[float]] | None = None) -> _Matrix[np.float64]: ...
43+
@overload
44+
def toeplitz(c: Sequence[opt.Just[complex]], r: Sequence[opt.Just[complex]] | None = None) -> _Matrix[np.complex128]: ...
45+
@overload
46+
def toeplitz(c: onp.CanArrayND[_SCT] | Sequence[_SCT], r: onp.CanArrayND[_SCT] | None = None) -> _Matrix[_SCT]: ...
47+
48+
#
49+
@overload
50+
def circulant(c: Sequence[opt.JustInt]) -> _Matrix[np.int_]: ...
51+
@overload
52+
def circulant(c: Sequence[opt.Just[float]]) -> _Matrix[np.float64]: ...
53+
@overload
54+
def circulant(c: Sequence[opt.Just[complex]]) -> _Matrix[np.complex128]: ...
55+
@overload
56+
def circulant(c: onp.CanArrayND[_SCT] | Sequence[_SCT]) -> _Matrix[_SCT]: ...
57+
58+
#
59+
@overload
60+
def hankel(c: Sequence[opt.JustInt], r: Sequence[opt.JustInt] | None = None) -> _Matrix[np.int_]: ...
61+
@overload
62+
def hankel(c: Sequence[opt.Just[float]], r: Sequence[opt.Just[float]] | None = None) -> _Matrix[np.float64]: ...
63+
@overload
64+
def hankel(c: Sequence[opt.Just[complex]], r: Sequence[opt.Just[complex]] | None = None) -> _Matrix[np.complex128]: ...
65+
@overload
66+
def hankel(c: onp.CanArrayND[_SCT] | Sequence[_SCT], r: onp.CanArrayND[_SCT] | None = None) -> _Matrix[_SCT]: ...
67+
68+
#
69+
@overload
70+
def hadamard(n: onp.ToInt, dtype: type[opt.JustInt]) -> _Matrix[np.int_]: ...
71+
@overload
72+
def hadamard(n: onp.ToInt, dtype: type[opt.Just[float]]) -> _Matrix[np.float64]: ...
73+
@overload
74+
def hadamard(n: onp.ToInt, dtype: type[opt.Just[complex]]) -> _Matrix[np.complex128]: ...
75+
@overload
76+
def hadamard(n: onp.ToInt, dtype: onp.HasDType[np.dtype[_SCT]] | np.dtype[_SCT]) -> _Matrix[_SCT]: ...
77+
@overload
78+
def hadamard(n: onp.ToInt, dtype: npt.DTypeLike = ...) -> _Matrix[np.generic]: ...
79+
80+
#
81+
@overload
82+
def leslie(f: Sequence[opt.JustInt], s: Sequence[opt.JustInt]) -> _Matrix[np.int_]: ...
83+
@overload
84+
def leslie(f: Sequence[opt.Just[float]], s: Sequence[opt.Just[float]]) -> _Matrix[np.float64]: ...
85+
@overload
86+
def leslie(f: Sequence[opt.Just[complex]], s: Sequence[opt.Just[complex]]) -> _Matrix[np.complex128]: ...
87+
@overload
88+
def leslie(f: onp.CanArrayND[_SCT] | Sequence[_SCT], s: onp.CanArrayND[_SCT] | Sequence[_SCT]) -> _Matrix[_SCT]: ...
89+
90+
#
91+
def kron(a: onp.ArrayND[_SCT], b: onp.ArrayND[_SCT]) -> _Matrix[_SCT]: ...
92+
93+
#
94+
@overload
95+
def block_diag() -> _Matrix[np.float64]: ...
96+
@overload
97+
def block_diag(arr0: Sequence[opt.JustInt], /, *arrs: Sequence[opt.JustInt]) -> _Matrix[np.int_]: ...
98+
@overload
99+
def block_diag(arr0: Sequence[opt.Just[float]], /, *arrs: Sequence[opt.Just[float]]) -> _Matrix[np.float64]: ...
100+
@overload
101+
def block_diag(arr0: Sequence[opt.Just[complex]], /, *arrs: Sequence[opt.Just[complex]]) -> _Matrix[np.complex128]: ...
102+
@overload
103+
def block_diag(arr0: onp.CanArrayND[_SCT] | Sequence[_SCT], /, *arrs: onp.CanArrayND[_SCT] | Sequence[_SCT]) -> _Matrix[_SCT]: ...
104+
105+
#
106+
@overload
107+
def companion(a: Sequence[opt.JustInt]) -> _Matrix[np.int_]: ...
108+
@overload
109+
def companion(a: Sequence[opt.Just[float]]) -> _Matrix[np.float64]: ...
110+
@overload
111+
def companion(a: Sequence[opt.Just[complex]]) -> _Matrix[np.complex128]: ...
112+
@overload
113+
def companion(a: onp.CanArrayND[_SCT] | Sequence[_SCT]) -> _Matrix[_SCT]: ...
114+
115+
#
45116
def helmert(n: onp.ToInt, full: bool = False) -> _Matrix[np.float64]: ...
117+
118+
#
46119
def hilbert(n: onp.ToInt) -> _Matrix[np.float64]: ...
47-
def fiedler(a: npt.ArrayLike) -> _Matrix: ...
48-
def fiedler_companion(a: npt.ArrayLike) -> _Matrix: ...
49-
def convolution_matrix(a: npt.ArrayLike, n: onp.ToInt, mode: CorrelateMode = "full") -> _Matrix: ...
120+
121+
#
122+
@overload
123+
def fiedler(a: Sequence[opt.JustInt]) -> _Matrix[np.int_]: ...
124+
@overload
125+
def fiedler(a: Sequence[opt.Just[float]]) -> _Matrix[np.float64]: ...
126+
@overload
127+
def fiedler(a: Sequence[opt.Just[complex]]) -> _Matrix[np.complex128]: ...
128+
@overload
129+
def fiedler(a: onp.CanArrayND[_SCT] | Sequence[_SCT]) -> _Matrix[_SCT]: ...
130+
131+
#
132+
@overload
133+
def fiedler_companion(a: Sequence[opt.JustInt]) -> _Matrix[np.int_]: ...
134+
@overload
135+
def fiedler_companion(a: Sequence[opt.Just[float]]) -> _Matrix[np.float64]: ...
136+
@overload
137+
def fiedler_companion(a: Sequence[opt.Just[complex]]) -> _Matrix[np.complex128]: ...
138+
@overload
139+
def fiedler_companion(a: onp.CanArrayND[_SCT] | Sequence[_SCT]) -> _Matrix[_SCT]: ...
140+
141+
# TODO
142+
@overload
143+
def convolution_matrix(a: Sequence[opt.JustInt], n: onp.ToInt, mode: CorrelateMode = "full") -> _Matrix[np.int_]: ...
144+
@overload
145+
def convolution_matrix(a: Sequence[opt.Just[float]], n: onp.ToInt, mode: CorrelateMode = "full") -> _Matrix[np.float64]: ...
146+
@overload
147+
def convolution_matrix(a: Sequence[opt.Just[complex]], n: onp.ToInt, mode: CorrelateMode = "full") -> _Matrix[np.complex128]: ...
148+
@overload
149+
def convolution_matrix(a: onp.CanArrayND[_SCT] | Sequence[_SCT], n: onp.ToInt, mode: CorrelateMode = "full") -> _Matrix[_SCT]: ...
50150

51151
#
52152
@overload
53153
def invhilbert(n: onp.ToInt, exact: Literal[False] = False) -> _Matrix[np.float64]: ...
54154
@overload
55155
def invhilbert(n: onp.ToInt, exact: Literal[True]) -> _Matrix[np.int64] | _Matrix[np.object_]: ...
156+
157+
#
56158
@overload
57159
def pascal(n: onp.ToInt, kind: _Kind = "symmetric", exact: Literal[True] = True) -> _Matrix[np.uint64 | np.object_]: ...
58160
@overload
59161
def pascal(n: onp.ToInt, kind: _Kind = "symmetric", *, exact: Literal[False]) -> _Matrix[np.float64]: ...
60162
@overload
61163
def pascal(n: onp.ToInt, kind: _Kind, exact: Literal[False]) -> _Matrix[np.float64]: ...
164+
165+
#
62166
@overload
63167
def invpascal(n: onp.ToInt, kind: _Kind = "symmetric", exact: Literal[True] = True) -> _Matrix[np.int64 | np.object_]: ...
64168
@overload
65169
def invpascal(n: onp.ToInt, kind: _Kind = "symmetric", *, exact: Literal[False]) -> _Matrix[np.float64]: ...
66170
@overload
67171
def invpascal(n: onp.ToInt, kind: _Kind, exact: Literal[False]) -> _Matrix[np.float64]: ...
172+
173+
#
68174
def dft(n: onp.ToInt, scale: Literal["sqrtn", "n"] | None = None) -> _Matrix[np.complex128]: ...

0 commit comments

Comments
 (0)