@@ -4,69 +4,91 @@ from typing_extensions import TypeVar
4
4
import numpy as np
5
5
import optype as op
6
6
import optype .numpy as onp
7
+ import optype .numpy .compat as npc
7
8
8
9
from scipy ._typing import Falsy , Truthy
9
10
10
11
__all__ = ["diagsvd" , "null_space" , "orth" , "subspace_angles" , "svd" , "svdvals" ]
11
12
13
+ _T = TypeVar ("_T" )
14
+ _Tuple3 : TypeAlias = tuple [_T , _T , _T ]
15
+
12
16
_Float : TypeAlias = np .float32 | np .float64
13
17
_FloatND : TypeAlias = onp .ArrayND [_Float ]
14
18
15
19
_Complex : TypeAlias = np .complex64 | np .complex128
16
- _ComplexND : TypeAlias = onp .ArrayND [_Complex ]
17
20
18
21
_LapackDriver : TypeAlias = Literal ["gesdd" , "gesvd" ]
19
22
20
- _FloatSVD : TypeAlias = tuple [_FloatND , _FloatND , _FloatND ]
21
- _ComplexSVD : TypeAlias = tuple [_ComplexND , _FloatND , _ComplexND ]
22
-
23
23
_RealT = TypeVar ("_RealT" , bound = np .bool_ | np .integer [Any ] | np .floating [Any ])
24
24
_InexactT = TypeVar ("_InexactT" , bound = _Float | _Complex )
25
25
26
+ _as_f32 : TypeAlias = np .float32 | np .float16 # noqa: PYI042
27
+ _as_f64 : TypeAlias = np .longdouble | np .float64 | npc .integer | np .bool_ # noqa: PYI042
28
+
26
29
###
27
30
28
- @overload
31
+ @overload # nd float64
29
32
def svd (
30
- a : onp .ToFloatND ,
33
+ a : onp .ToArrayND [ float , _as_f64 ] ,
31
34
full_matrices : onp .ToBool = True ,
32
35
compute_uv : Truthy = True ,
33
36
overwrite_a : onp .ToBool = False ,
34
37
check_finite : onp .ToBool = True ,
35
38
lapack_driver : _LapackDriver = "gesdd" ,
36
- ) -> _FloatSVD : ...
37
- @overload
39
+ ) -> _Tuple3 [onp .ArrayND [np .float64 ]]: ...
40
+ @overload # nd float32
41
+ def svd (
42
+ a : onp .ToArrayND [_as_f32 , _as_f32 ],
43
+ full_matrices : onp .ToBool = True ,
44
+ compute_uv : Truthy = True ,
45
+ overwrite_a : onp .ToBool = False ,
46
+ check_finite : onp .ToBool = True ,
47
+ lapack_driver : _LapackDriver = "gesdd" ,
48
+ ) -> _Tuple3 [onp .ArrayND [np .float32 ]]: ...
49
+ @overload # nd complex128
38
50
def svd (
39
- a : onp .ToComplexND ,
51
+ a : onp .ToArrayND [ op . JustComplex , np . complex128 | np . clongdouble ] ,
40
52
full_matrices : onp .ToBool = True ,
41
53
compute_uv : Truthy = True ,
42
54
overwrite_a : onp .ToBool = False ,
43
55
check_finite : onp .ToBool = True ,
44
56
lapack_driver : _LapackDriver = "gesdd" ,
45
- ) -> _FloatSVD | _ComplexSVD : ...
46
- @overload # complex, compute_uv: {False}
57
+ ) -> tuple [ onp . ArrayND [ np . complex128 ], onp . ArrayND [ np . float64 ], onp . ArrayND [ np . complex128 ]] : ...
58
+ @overload # nd complex64
47
59
def svd (
48
- a : onp .ToComplexND ,
49
- full_matrices : onp .ToBool ,
60
+ a : onp .ToArrayND [np .complex64 , np .complex64 ],
61
+ full_matrices : onp .ToBool = True ,
62
+ compute_uv : Truthy = True ,
63
+ overwrite_a : onp .ToBool = False ,
64
+ check_finite : onp .ToBool = True ,
65
+ lapack_driver : _LapackDriver = "gesdd" ,
66
+ ) -> tuple [onp .ArrayND [np .complex64 ], onp .ArrayND [np .float32 ], onp .ArrayND [np .complex64 ]]: ...
67
+ @overload # nd float64 | complex128, compute_uv=False (keyword)
68
+ def svd (
69
+ a : onp .ToArrayND [complex , _as_f64 | np .complex128 | np .clongdouble ],
70
+ full_matrices : onp .ToBool = True ,
71
+ * ,
50
72
compute_uv : Falsy ,
51
73
overwrite_a : onp .ToBool = False ,
52
74
check_finite : onp .ToBool = True ,
53
75
lapack_driver : _LapackDriver = "gesdd" ,
54
- ) -> _FloatND : ...
55
- @overload # complex, * , compute_uv: { False}
76
+ ) -> onp . ArrayND [ np . float64 ] : ...
77
+ @overload # nd float32 | complex64 , compute_uv= False (keyword)
56
78
def svd (
57
- a : onp .ToComplexND ,
79
+ a : onp .ToArrayND [ _as_f32 , _as_f32 | np . complex64 ] ,
58
80
full_matrices : onp .ToBool = True ,
59
81
* ,
60
82
compute_uv : Falsy ,
61
83
overwrite_a : onp .ToBool = False ,
62
84
check_finite : onp .ToBool = True ,
63
85
lapack_driver : _LapackDriver = "gesdd" ,
64
- ) -> _FloatND : ...
86
+ ) -> onp . ArrayND [ np . float32 ] : ...
65
87
66
88
#
67
89
def svdvals (a : onp .ToComplexND , overwrite_a : onp .ToBool = False , check_finite : onp .ToBool = True ) -> _FloatND : ...
68
90
69
- # beware the overlapping overloads for bool <: int (<: float)
91
+ #
70
92
@overload
71
93
def diagsvd (s : onp .SequenceND [_RealT ] | onp .CanArrayND [_RealT ], M : op .CanIndex , N : op .CanIndex ) -> onp .ArrayND [_RealT ]: ...
72
94
@overload
0 commit comments