1
- from typing import Literal , TypeAlias , overload
2
- from typing_extensions import TypeVar
1
+ from typing import Literal , TypeAlias , TypeVar , overload
3
2
4
3
import numpy as np
5
4
import optype as op
@@ -8,87 +7,90 @@ import optype.numpy.compat as npc
8
7
9
8
__all__ = ["diagsvd" , "null_space" , "orth" , "subspace_angles" , "svd" , "svdvals" ]
10
9
11
- _T = TypeVar ("_T" )
12
- _Tuple3 : TypeAlias = tuple [_T , _T , _T ]
13
-
14
- _Float : TypeAlias = np .float32 | np .float64
15
- _FloatND : TypeAlias = onp .ArrayND [_Float ]
16
-
17
- _Complex : TypeAlias = np .complex64 | np .complex128
18
-
19
- _LapackDriver : TypeAlias = Literal ["gesdd" , "gesvd" ]
20
-
21
10
_RealT = TypeVar ("_RealT" , bound = np .bool_ | npc .integer | npc .floating )
22
11
_InexactT = TypeVar ("_InexactT" , bound = _Float | _Complex )
12
+ _ScalarT = TypeVar ("_ScalarT" , bound = np .generic )
13
+ _ScalarT1 = TypeVar ("_ScalarT1" , bound = np .generic )
14
+
15
+ _SVD_ND : TypeAlias = tuple [onp .ArrayND [_ScalarT ], onp .ArrayND [_ScalarT1 ], onp .ArrayND [_ScalarT ]]
16
+
17
+ _Float : TypeAlias = np .float64 | np .float32
18
+ _Complex : TypeAlias = np .complex128 | np .complex64
23
19
24
20
_as_f32 : TypeAlias = np .float32 | np .float16 # noqa: PYI042
25
21
_as_f64 : TypeAlias = np .longdouble | np .float64 | npc .integer | np .bool_ # noqa: PYI042
22
+ _as_c128 : TypeAlias = np .complex128 | np .clongdouble # noqa: PYI042
23
+
24
+ _ToSafeFloat64ND : TypeAlias = onp .ToArrayND [float , np .float64 | npc .integer | np .bool_ ]
25
+ _ToArrayND : TypeAlias = onp .CanArrayND [_ScalarT ] | onp .SequenceND [_ScalarT ]
26
+
27
+ _LapackDriver : TypeAlias = Literal ["gesdd" , "gesvd" ]
26
28
27
29
###
28
30
29
31
@overload # nd float64
30
32
def svd (
31
33
a : onp .ToArrayND [float , _as_f64 ],
32
- full_matrices : onp . ToBool = True ,
33
- compute_uv : onp . ToTrue = True ,
34
- overwrite_a : onp . ToBool = False ,
35
- check_finite : onp . ToBool = True ,
34
+ full_matrices : bool = True ,
35
+ compute_uv : Literal [ True ] = True ,
36
+ overwrite_a : bool = False ,
37
+ check_finite : bool = True ,
36
38
lapack_driver : _LapackDriver = "gesdd" ,
37
- ) -> _Tuple3 [ onp . ArrayND [ np .float64 ] ]: ...
39
+ ) -> _SVD_ND [ np . float64 , np .float64 ]: ...
38
40
@overload # nd float32
39
41
def svd (
40
- a : onp .ToArrayND [ _as_f32 , _as_f32 ],
41
- full_matrices : onp . ToBool = True ,
42
- compute_uv : onp . ToTrue = True ,
43
- overwrite_a : onp . ToBool = False ,
44
- check_finite : onp . ToBool = True ,
42
+ a : onp .CanArrayND [ _as_f32 ],
43
+ full_matrices : bool = True ,
44
+ compute_uv : Literal [ True ] = True ,
45
+ overwrite_a : bool = False ,
46
+ check_finite : bool = True ,
45
47
lapack_driver : _LapackDriver = "gesdd" ,
46
- ) -> _Tuple3 [ onp . ArrayND [ np .float32 ] ]: ...
48
+ ) -> _SVD_ND [ np . float32 , np .float32 ]: ...
47
49
@overload # nd complex128
48
50
def svd (
49
- a : onp .ToArrayND [op .JustComplex , np . complex128 | np . clongdouble ],
50
- full_matrices : onp . ToBool = True ,
51
- compute_uv : onp . ToTrue = True ,
52
- overwrite_a : onp . ToBool = False ,
53
- check_finite : onp . ToBool = True ,
51
+ a : onp .ToArrayND [op .JustComplex , _as_c128 ],
52
+ full_matrices : bool = True ,
53
+ compute_uv : Literal [ True ] = True ,
54
+ overwrite_a : bool = False ,
55
+ check_finite : bool = True ,
54
56
lapack_driver : _LapackDriver = "gesdd" ,
55
- ) -> tuple [ onp . ArrayND [ np .complex128 ], onp . ArrayND [ np .float64 ], onp . ArrayND [ np . complex128 ] ]: ...
57
+ ) -> _SVD_ND [ np .complex128 , np .float64 ]: ...
56
58
@overload # nd complex64
57
59
def svd (
58
- a : onp .ToArrayND [ np . complex64 , np .complex64 ],
59
- full_matrices : onp . ToBool = True ,
60
- compute_uv : onp . ToTrue = True ,
61
- overwrite_a : onp . ToBool = False ,
62
- check_finite : onp . ToBool = True ,
60
+ a : onp .CanArrayND [ np .complex64 ],
61
+ full_matrices : bool = True ,
62
+ compute_uv : Literal [ True ] = True ,
63
+ overwrite_a : bool = False ,
64
+ check_finite : bool = True ,
63
65
lapack_driver : _LapackDriver = "gesdd" ,
64
- ) -> tuple [ onp . ArrayND [ np .complex64 ], onp . ArrayND [ np .float32 ], onp . ArrayND [ np . complex64 ] ]: ...
66
+ ) -> _SVD_ND [ np .complex64 , np .float32 ]: ...
65
67
@overload # nd float64 | complex128, compute_uv=False (keyword)
66
68
def svd (
67
- a : onp .ToArrayND [complex , _as_f64 | np . complex128 | np . clongdouble ],
68
- full_matrices : onp . ToBool = True ,
69
+ a : onp .ToArrayND [complex , _as_f64 | _as_c128 ],
70
+ full_matrices : bool = True ,
69
71
* ,
70
- compute_uv : onp . ToFalse ,
71
- overwrite_a : onp . ToBool = False ,
72
- check_finite : onp . ToBool = True ,
72
+ compute_uv : Literal [ False ] ,
73
+ overwrite_a : bool = False ,
74
+ check_finite : bool = True ,
73
75
lapack_driver : _LapackDriver = "gesdd" ,
74
76
) -> onp .ArrayND [np .float64 ]: ...
75
77
@overload # nd float32 | complex64, compute_uv=False (keyword)
76
78
def svd (
77
- a : onp .ToArrayND [ _as_f32 , _as_f32 | np .complex64 ],
78
- full_matrices : onp . ToBool = True ,
79
+ a : onp .CanArrayND [ _as_f32 | np .complex64 ],
80
+ full_matrices : bool = True ,
79
81
* ,
80
- compute_uv : onp . ToFalse ,
81
- overwrite_a : onp . ToBool = False ,
82
- check_finite : onp . ToBool = True ,
82
+ compute_uv : Literal [ False ] ,
83
+ overwrite_a : bool = False ,
84
+ check_finite : bool = True ,
83
85
lapack_driver : _LapackDriver = "gesdd" ,
84
86
) -> onp .ArrayND [np .float32 ]: ...
85
87
86
88
#
87
- def svdvals (a : onp .ToComplexND , overwrite_a : onp . ToBool = False , check_finite : onp . ToBool = True ) -> _FloatND : ...
89
+ def svdvals (a : onp .ToComplexND , overwrite_a : bool = False , check_finite : bool = True ) -> onp . ArrayND [ np . float64 | np . float32 ] : ...
88
90
89
91
#
90
92
@overload
91
- def diagsvd (s : onp . SequenceND [ _RealT ] | onp . CanArrayND [_RealT ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [_RealT ]: ...
93
+ def diagsvd (s : _ToArrayND [_RealT ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [_RealT ]: ...
92
94
@overload
93
95
def diagsvd (s : onp .SequenceND [bool ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [np .bool_ ]: ...
94
96
@overload
@@ -98,42 +100,40 @@ def diagsvd(s: onp.SequenceND[op.JustFloat], M: op.CanIndex, N: op.CanIndex) ->
98
100
99
101
#
100
102
@overload
101
- def orth (A : onp . ToIntND | onp . ToJustFloat64_ND , rcond : onp . ToFloat | None = None ) -> onp .ArrayND [np .float64 ]: ...
103
+ def orth (A : _ToSafeFloat64ND , rcond : float | None = None ) -> onp .ArrayND [np .float64 ]: ...
102
104
@overload
103
- def orth (A : onp .ToJustComplex128_ND , rcond : onp . ToFloat | None = None ) -> onp .ArrayND [np .complex128 ]: ...
105
+ def orth (A : onp .ToJustComplex128_ND , rcond : float | None = None ) -> onp .ArrayND [np .complex128 ]: ...
104
106
@overload
105
- def orth (
106
- A : onp .SequenceND [_InexactT ] | onp .CanArrayND [_InexactT ], rcond : onp .ToFloat | None = None
107
- ) -> onp .ArrayND [_InexactT ]: ...
107
+ def orth (A : _ToArrayND [_InexactT ], rcond : float | None = None ) -> onp .ArrayND [_InexactT ]: ...
108
108
109
109
#
110
110
@overload
111
111
def null_space (
112
- A : onp . ToIntND | onp . ToJustFloat64_ND ,
113
- rcond : onp . ToFloat | None = None ,
112
+ A : _ToSafeFloat64ND ,
113
+ rcond : float | None = None ,
114
114
* ,
115
- overwrite_a : onp . ToBool = False ,
116
- check_finite : onp . ToBool = True ,
115
+ overwrite_a : bool = False ,
116
+ check_finite : bool = True ,
117
117
lapack_driver : _LapackDriver = "gesdd" ,
118
118
) -> onp .ArrayND [np .float64 ]: ...
119
119
@overload
120
120
def null_space (
121
121
A : onp .ToJustComplex128_ND ,
122
- rcond : onp . ToFloat | None = None ,
122
+ rcond : float | None = None ,
123
123
* ,
124
- overwrite_a : onp . ToBool = False ,
125
- check_finite : onp . ToBool = True ,
124
+ overwrite_a : bool = False ,
125
+ check_finite : bool = True ,
126
126
lapack_driver : _LapackDriver = "gesdd" ,
127
127
) -> onp .ArrayND [np .complex128 ]: ...
128
128
@overload
129
129
def null_space (
130
- A : onp . SequenceND [ _InexactT ] | onp . CanArrayND [_InexactT ],
131
- rcond : onp . ToFloat | None = None ,
130
+ A : _ToArrayND [_InexactT ],
131
+ rcond : float | None = None ,
132
132
* ,
133
- overwrite_a : onp . ToBool = False ,
134
- check_finite : onp . ToBool = True ,
133
+ overwrite_a : bool = False ,
134
+ check_finite : bool = True ,
135
135
lapack_driver : _LapackDriver = "gesdd" ,
136
136
) -> onp .ArrayND [_InexactT ]: ...
137
137
138
138
#
139
- def subspace_angles (A : onp .ToComplexND , B : onp .ToComplexND ) -> _FloatND : ...
139
+ def subspace_angles (A : onp .ToComplexND , B : onp .ToComplexND ) -> onp . ArrayND [ np . float64 | np . float32 ] : ...
0 commit comments