Skip to content

Commit 88a4151

Browse files
authored
linalg: improved svd annotations (#700)
2 parents 9c447a3 + 2ba937e commit 88a4151

File tree

2 files changed

+94
-18
lines changed

2 files changed

+94
-18
lines changed

scipy-stubs/linalg/_decomp_svd.pyi

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,69 +4,91 @@ from typing_extensions import TypeVar
44
import numpy as np
55
import optype as op
66
import optype.numpy as onp
7+
import optype.numpy.compat as npc
78

89
from scipy._typing import Falsy, Truthy
910

1011
__all__ = ["diagsvd", "null_space", "orth", "subspace_angles", "svd", "svdvals"]
1112

13+
_T = TypeVar("_T")
14+
_Tuple3: TypeAlias = tuple[_T, _T, _T]
15+
1216
_Float: TypeAlias = np.float32 | np.float64
1317
_FloatND: TypeAlias = onp.ArrayND[_Float]
1418

1519
_Complex: TypeAlias = np.complex64 | np.complex128
16-
_ComplexND: TypeAlias = onp.ArrayND[_Complex]
1720

1821
_LapackDriver: TypeAlias = Literal["gesdd", "gesvd"]
1922

20-
_FloatSVD: TypeAlias = tuple[_FloatND, _FloatND, _FloatND]
21-
_ComplexSVD: TypeAlias = tuple[_ComplexND, _FloatND, _ComplexND]
22-
2323
_RealT = TypeVar("_RealT", bound=np.bool_ | np.integer[Any] | np.floating[Any])
2424
_InexactT = TypeVar("_InexactT", bound=_Float | _Complex)
2525

26+
_as_f32: TypeAlias = np.float32 | np.float16 # noqa: PYI042
27+
_as_f64: TypeAlias = np.longdouble | np.float64 | npc.integer | np.bool_ # noqa: PYI042
28+
2629
###
2730

28-
@overload
31+
@overload # nd float64
2932
def svd(
30-
a: onp.ToFloatND,
33+
a: onp.ToArrayND[float, _as_f64],
3134
full_matrices: onp.ToBool = True,
3235
compute_uv: Truthy = True,
3336
overwrite_a: onp.ToBool = False,
3437
check_finite: onp.ToBool = True,
3538
lapack_driver: _LapackDriver = "gesdd",
36-
) -> _FloatSVD: ...
37-
@overload
39+
) -> _Tuple3[onp.ArrayND[np.float64]]: ...
40+
@overload # nd float32
41+
def svd(
42+
a: onp.ToArrayND[_as_f32, _as_f32],
43+
full_matrices: onp.ToBool = True,
44+
compute_uv: Truthy = True,
45+
overwrite_a: onp.ToBool = False,
46+
check_finite: onp.ToBool = True,
47+
lapack_driver: _LapackDriver = "gesdd",
48+
) -> _Tuple3[onp.ArrayND[np.float32]]: ...
49+
@overload # nd complex128
3850
def svd(
39-
a: onp.ToComplexND,
51+
a: onp.ToArrayND[op.JustComplex, np.complex128 | np.clongdouble],
4052
full_matrices: onp.ToBool = True,
4153
compute_uv: Truthy = True,
4254
overwrite_a: onp.ToBool = False,
4355
check_finite: onp.ToBool = True,
4456
lapack_driver: _LapackDriver = "gesdd",
45-
) -> _FloatSVD | _ComplexSVD: ...
46-
@overload # complex, compute_uv: {False}
57+
) -> tuple[onp.ArrayND[np.complex128], onp.ArrayND[np.float64], onp.ArrayND[np.complex128]]: ...
58+
@overload # nd complex64
4759
def svd(
48-
a: onp.ToComplexND,
49-
full_matrices: onp.ToBool,
60+
a: onp.ToArrayND[np.complex64, np.complex64],
61+
full_matrices: onp.ToBool = True,
62+
compute_uv: Truthy = True,
63+
overwrite_a: onp.ToBool = False,
64+
check_finite: onp.ToBool = True,
65+
lapack_driver: _LapackDriver = "gesdd",
66+
) -> tuple[onp.ArrayND[np.complex64], onp.ArrayND[np.float32], onp.ArrayND[np.complex64]]: ...
67+
@overload # nd float64 | complex128, compute_uv=False (keyword)
68+
def svd(
69+
a: onp.ToArrayND[complex, _as_f64 | np.complex128 | np.clongdouble],
70+
full_matrices: onp.ToBool = True,
71+
*,
5072
compute_uv: Falsy,
5173
overwrite_a: onp.ToBool = False,
5274
check_finite: onp.ToBool = True,
5375
lapack_driver: _LapackDriver = "gesdd",
54-
) -> _FloatND: ...
55-
@overload # complex, *, compute_uv: {False}
76+
) -> onp.ArrayND[np.float64]: ...
77+
@overload # nd float32 | complex64, compute_uv=False (keyword)
5678
def svd(
57-
a: onp.ToComplexND,
79+
a: onp.ToArrayND[_as_f32, _as_f32 | np.complex64],
5880
full_matrices: onp.ToBool = True,
5981
*,
6082
compute_uv: Falsy,
6183
overwrite_a: onp.ToBool = False,
6284
check_finite: onp.ToBool = True,
6385
lapack_driver: _LapackDriver = "gesdd",
64-
) -> _FloatND: ...
86+
) -> onp.ArrayND[np.float32]: ...
6587

6688
#
6789
def svdvals(a: onp.ToComplexND, overwrite_a: onp.ToBool = False, check_finite: onp.ToBool = True) -> _FloatND: ...
6890

69-
# beware the overlapping overloads for bool <: int (<: float)
91+
#
7092
@overload
7193
def diagsvd(s: onp.SequenceND[_RealT] | onp.CanArrayND[_RealT], M: op.CanIndex, N: op.CanIndex) -> onp.ArrayND[_RealT]: ...
7294
@overload

tests/linalg/test_decomp_svd.pyi

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import TypeAlias, assert_type
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
from scipy.linalg import svd
7+
8+
ArrayF32: TypeAlias = npt.NDArray[np.float32]
9+
ArrayF64: TypeAlias = npt.NDArray[np.float64]
10+
ArrayC64: TypeAlias = npt.NDArray[np.complex64]
11+
ArrayC128: TypeAlias = npt.NDArray[np.complex128]
12+
13+
###
14+
15+
py_i_2d: list[list[int]]
16+
py_f_2d: list[list[float]]
17+
py_c_2d: list[list[complex]]
18+
19+
f16_nd: npt.NDArray[np.float16]
20+
f32_nd: npt.NDArray[np.float32]
21+
f64_nd: npt.NDArray[np.float64]
22+
f80_nd: npt.NDArray[np.longdouble]
23+
24+
c64_nd: npt.NDArray[np.complex64]
25+
c128_nd: npt.NDArray[np.complex128]
26+
c160_nd: npt.NDArray[np.clongdouble]
27+
28+
###
29+
# svd
30+
31+
assert_type(svd(py_i_2d), tuple[ArrayF64, ArrayF64, ArrayF64])
32+
assert_type(svd(py_f_2d), tuple[ArrayF64, ArrayF64, ArrayF64])
33+
assert_type(svd(py_c_2d), tuple[ArrayC128, ArrayF64, ArrayC128])
34+
assert_type(svd(f16_nd), tuple[ArrayF32, ArrayF32, ArrayF32])
35+
assert_type(svd(f32_nd), tuple[ArrayF32, ArrayF32, ArrayF32])
36+
assert_type(svd(f64_nd), tuple[ArrayF64, ArrayF64, ArrayF64])
37+
assert_type(svd(f80_nd), tuple[ArrayF64, ArrayF64, ArrayF64])
38+
assert_type(svd(c64_nd), tuple[ArrayC64, ArrayF32, ArrayC64])
39+
assert_type(svd(c128_nd), tuple[ArrayC128, ArrayF64, ArrayC128])
40+
assert_type(svd(c160_nd), tuple[ArrayC128, ArrayF64, ArrayC128])
41+
42+
assert_type(svd(py_i_2d, compute_uv=False), ArrayF64)
43+
assert_type(svd(py_f_2d, compute_uv=False), ArrayF64)
44+
assert_type(svd(py_c_2d, compute_uv=False), ArrayF64)
45+
assert_type(svd(f16_nd, compute_uv=False), ArrayF32)
46+
assert_type(svd(f32_nd, compute_uv=False), ArrayF32)
47+
assert_type(svd(f64_nd, compute_uv=False), ArrayF64)
48+
assert_type(svd(f80_nd, compute_uv=False), ArrayF64)
49+
assert_type(svd(c64_nd, compute_uv=False), ArrayF32)
50+
assert_type(svd(c128_nd, compute_uv=False), ArrayF64)
51+
assert_type(svd(c160_nd, compute_uv=False), ArrayF64)
52+
53+
####
54+
# TODO: test the remaining functions in `_decomp_svd.pyi`

0 commit comments

Comments
 (0)