Skip to content

Commit b6bd9f5

Browse files
committed
🎨 linalg: simplified _decomp_svd type aliases
1 parent df7bada commit b6bd9f5

File tree

1 file changed

+64
-64
lines changed

1 file changed

+64
-64
lines changed
Lines changed: 64 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from typing import Literal, TypeAlias, overload
2-
from typing_extensions import TypeVar
1+
from typing import Literal, TypeAlias, TypeVar, overload
32

43
import numpy as np
54
import optype as op
@@ -8,87 +7,90 @@ import optype.numpy.compat as npc
87

98
__all__ = ["diagsvd", "null_space", "orth", "subspace_angles", "svd", "svdvals"]
109

11-
_T = TypeVar("_T")
12-
_Tuple3: TypeAlias = tuple[_T, _T, _T]
13-
14-
_Float: TypeAlias = np.float32 | np.float64
15-
_FloatND: TypeAlias = onp.ArrayND[_Float]
16-
17-
_Complex: TypeAlias = np.complex64 | np.complex128
18-
19-
_LapackDriver: TypeAlias = Literal["gesdd", "gesvd"]
20-
2110
_RealT = TypeVar("_RealT", bound=np.bool_ | npc.integer | npc.floating)
2211
_InexactT = TypeVar("_InexactT", bound=_Float | _Complex)
12+
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
13+
_ScalarT1 = TypeVar("_ScalarT1", bound=np.generic)
14+
15+
_SVD_ND: TypeAlias = tuple[onp.ArrayND[_ScalarT], onp.ArrayND[_ScalarT1], onp.ArrayND[_ScalarT]]
16+
17+
_Float: TypeAlias = np.float64 | np.float32
18+
_Complex: TypeAlias = np.complex128 | np.complex64
2319

2420
_as_f32: TypeAlias = np.float32 | np.float16 # noqa: PYI042
2521
_as_f64: TypeAlias = np.longdouble | np.float64 | npc.integer | np.bool_ # noqa: PYI042
22+
_as_c128: TypeAlias = np.complex128 | np.clongdouble # noqa: PYI042
23+
24+
_ToSafeFloat64ND: TypeAlias = onp.ToArrayND[float, np.float64 | npc.integer | np.bool_]
25+
_ToArrayND: TypeAlias = onp.CanArrayND[_ScalarT] | onp.SequenceND[_ScalarT]
26+
27+
_LapackDriver: TypeAlias = Literal["gesdd", "gesvd"]
2628

2729
###
2830

2931
@overload # nd float64
3032
def svd(
3133
a: onp.ToArrayND[float, _as_f64],
32-
full_matrices: onp.ToBool = True,
33-
compute_uv: onp.ToTrue = True,
34-
overwrite_a: onp.ToBool = False,
35-
check_finite: onp.ToBool = True,
34+
full_matrices: bool = True,
35+
compute_uv: Literal[True] = True,
36+
overwrite_a: bool = False,
37+
check_finite: bool = True,
3638
lapack_driver: _LapackDriver = "gesdd",
37-
) -> _Tuple3[onp.ArrayND[np.float64]]: ...
39+
) -> _SVD_ND[np.float64, np.float64]: ...
3840
@overload # nd float32
3941
def svd(
40-
a: onp.ToArrayND[_as_f32, _as_f32],
41-
full_matrices: onp.ToBool = True,
42-
compute_uv: onp.ToTrue = True,
43-
overwrite_a: onp.ToBool = False,
44-
check_finite: onp.ToBool = True,
42+
a: onp.CanArrayND[_as_f32],
43+
full_matrices: bool = True,
44+
compute_uv: Literal[True] = True,
45+
overwrite_a: bool = False,
46+
check_finite: bool = True,
4547
lapack_driver: _LapackDriver = "gesdd",
46-
) -> _Tuple3[onp.ArrayND[np.float32]]: ...
48+
) -> _SVD_ND[np.float32, np.float32]: ...
4749
@overload # nd complex128
4850
def svd(
49-
a: onp.ToArrayND[op.JustComplex, np.complex128 | np.clongdouble],
50-
full_matrices: onp.ToBool = True,
51-
compute_uv: onp.ToTrue = True,
52-
overwrite_a: onp.ToBool = False,
53-
check_finite: onp.ToBool = True,
51+
a: onp.ToArrayND[op.JustComplex, _as_c128],
52+
full_matrices: bool = True,
53+
compute_uv: Literal[True] = True,
54+
overwrite_a: bool = False,
55+
check_finite: bool = True,
5456
lapack_driver: _LapackDriver = "gesdd",
55-
) -> tuple[onp.ArrayND[np.complex128], onp.ArrayND[np.float64], onp.ArrayND[np.complex128]]: ...
57+
) -> _SVD_ND[np.complex128, np.float64]: ...
5658
@overload # nd complex64
5759
def svd(
58-
a: onp.ToArrayND[np.complex64, np.complex64],
59-
full_matrices: onp.ToBool = True,
60-
compute_uv: onp.ToTrue = True,
61-
overwrite_a: onp.ToBool = False,
62-
check_finite: onp.ToBool = True,
60+
a: onp.CanArrayND[np.complex64],
61+
full_matrices: bool = True,
62+
compute_uv: Literal[True] = True,
63+
overwrite_a: bool = False,
64+
check_finite: bool = True,
6365
lapack_driver: _LapackDriver = "gesdd",
64-
) -> tuple[onp.ArrayND[np.complex64], onp.ArrayND[np.float32], onp.ArrayND[np.complex64]]: ...
66+
) -> _SVD_ND[np.complex64, np.float32]: ...
6567
@overload # nd float64 | complex128, compute_uv=False (keyword)
6668
def svd(
67-
a: onp.ToArrayND[complex, _as_f64 | np.complex128 | np.clongdouble],
68-
full_matrices: onp.ToBool = True,
69+
a: onp.ToArrayND[complex, _as_f64 | _as_c128],
70+
full_matrices: bool = True,
6971
*,
70-
compute_uv: onp.ToFalse,
71-
overwrite_a: onp.ToBool = False,
72-
check_finite: onp.ToBool = True,
72+
compute_uv: Literal[False],
73+
overwrite_a: bool = False,
74+
check_finite: bool = True,
7375
lapack_driver: _LapackDriver = "gesdd",
7476
) -> onp.ArrayND[np.float64]: ...
7577
@overload # nd float32 | complex64, compute_uv=False (keyword)
7678
def svd(
77-
a: onp.ToArrayND[_as_f32, _as_f32 | np.complex64],
78-
full_matrices: onp.ToBool = True,
79+
a: onp.CanArrayND[_as_f32 | np.complex64],
80+
full_matrices: bool = True,
7981
*,
80-
compute_uv: onp.ToFalse,
81-
overwrite_a: onp.ToBool = False,
82-
check_finite: onp.ToBool = True,
82+
compute_uv: Literal[False],
83+
overwrite_a: bool = False,
84+
check_finite: bool = True,
8385
lapack_driver: _LapackDriver = "gesdd",
8486
) -> onp.ArrayND[np.float32]: ...
8587

8688
#
87-
def svdvals(a: onp.ToComplexND, overwrite_a: onp.ToBool = False, check_finite: onp.ToBool = True) -> _FloatND: ...
89+
def svdvals(a: onp.ToComplexND, overwrite_a: bool = False, check_finite: bool = True) -> onp.ArrayND[np.float64 | np.float32]: ...
8890

8991
#
9092
@overload
91-
def diagsvd(s: onp.SequenceND[_RealT] | onp.CanArrayND[_RealT], M: op.CanIndex, N: op.CanIndex) -> onp.ArrayND[_RealT]: ...
93+
def diagsvd(s: _ToArrayND[_RealT], M: op.CanIndex, N: op.CanIndex) -> onp.ArrayND[_RealT]: ...
9294
@overload
9395
def diagsvd(s: onp.SequenceND[bool], M: op.CanIndex, N: op.CanIndex) -> onp.ArrayND[np.bool_]: ...
9496
@overload
@@ -98,42 +100,40 @@ def diagsvd(s: onp.SequenceND[op.JustFloat], M: op.CanIndex, N: op.CanIndex) ->
98100

99101
#
100102
@overload
101-
def orth(A: onp.ToIntND | onp.ToJustFloat64_ND, rcond: onp.ToFloat | None = None) -> onp.ArrayND[np.float64]: ...
103+
def orth(A: _ToSafeFloat64ND, rcond: float | None = None) -> onp.ArrayND[np.float64]: ...
102104
@overload
103-
def orth(A: onp.ToJustComplex128_ND, rcond: onp.ToFloat | None = None) -> onp.ArrayND[np.complex128]: ...
105+
def orth(A: onp.ToJustComplex128_ND, rcond: float | None = None) -> onp.ArrayND[np.complex128]: ...
104106
@overload
105-
def orth(
106-
A: onp.SequenceND[_InexactT] | onp.CanArrayND[_InexactT], rcond: onp.ToFloat | None = None
107-
) -> onp.ArrayND[_InexactT]: ...
107+
def orth(A: _ToArrayND[_InexactT], rcond: float | None = None) -> onp.ArrayND[_InexactT]: ...
108108

109109
#
110110
@overload
111111
def null_space(
112-
A: onp.ToIntND | onp.ToJustFloat64_ND,
113-
rcond: onp.ToFloat | None = None,
112+
A: _ToSafeFloat64ND,
113+
rcond: float | None = None,
114114
*,
115-
overwrite_a: onp.ToBool = False,
116-
check_finite: onp.ToBool = True,
115+
overwrite_a: bool = False,
116+
check_finite: bool = True,
117117
lapack_driver: _LapackDriver = "gesdd",
118118
) -> onp.ArrayND[np.float64]: ...
119119
@overload
120120
def null_space(
121121
A: onp.ToJustComplex128_ND,
122-
rcond: onp.ToFloat | None = None,
122+
rcond: float | None = None,
123123
*,
124-
overwrite_a: onp.ToBool = False,
125-
check_finite: onp.ToBool = True,
124+
overwrite_a: bool = False,
125+
check_finite: bool = True,
126126
lapack_driver: _LapackDriver = "gesdd",
127127
) -> onp.ArrayND[np.complex128]: ...
128128
@overload
129129
def null_space(
130-
A: onp.SequenceND[_InexactT] | onp.CanArrayND[_InexactT],
131-
rcond: onp.ToFloat | None = None,
130+
A: _ToArrayND[_InexactT],
131+
rcond: float | None = None,
132132
*,
133-
overwrite_a: onp.ToBool = False,
134-
check_finite: onp.ToBool = True,
133+
overwrite_a: bool = False,
134+
check_finite: bool = True,
135135
lapack_driver: _LapackDriver = "gesdd",
136136
) -> onp.ArrayND[_InexactT]: ...
137137

138138
#
139-
def subspace_angles(A: onp.ToComplexND, B: onp.ToComplexND) -> _FloatND: ...
139+
def subspace_angles(A: onp.ToComplexND, B: onp.ToComplexND) -> onp.ArrayND[np.float64 | np.float32]: ...

0 commit comments

Comments
 (0)