1
- from typing import TypeAlias
1
+ from typing import TypeVar , overload
2
2
3
3
import numpy as np
4
4
import optype as op
5
5
import optype .numpy as onp
6
+ import optype .numpy .compat as npc
6
7
7
8
from ._realtransforms import dct , dctn , dst , dstn , idct , idst
8
9
from scipy ._typing import AnyShape , DCTType , NormalizationMode
9
10
10
11
__all__ = ["dct" , "dctn" , "dst" , "dstn" , "idct" , "idctn" , "idst" , "idstn" ]
11
12
12
- _RealND : TypeAlias = onp .ArrayND [np .float32 | np .float64 | np .longdouble ]
13
+ _ShapeT = TypeVar ("_ShapeT" , bound = tuple [int , ...])
14
+ _DTypeT = TypeVar ("_DTypeT" , bound = np .dtype [np .float32 | np .float64 | np .longdouble | npc .complexfloating ])
13
15
14
16
# NOTE: Unlike the ones in `scipy.fft._realtransforms`, `orthogonalize` is keyword-only here.
17
+
18
+ #
19
+ @overload
15
20
def idctn (
16
- x : onp .ToComplexND ,
21
+ x : onp .CanArrayND [npc .integer , _ShapeT ],
22
+ type : DCTType = 2 ,
23
+ s : onp .ToInt | onp .ToIntND | None = None ,
24
+ axes : AnyShape | None = None ,
25
+ norm : NormalizationMode | None = None ,
26
+ overwrite_x : op .CanBool = False ,
27
+ workers : onp .ToInt | None = None ,
28
+ * ,
29
+ orthogonalize : op .CanBool | None = None ,
30
+ ) -> onp .Array [_ShapeT , np .float64 ]: ...
31
+ @overload
32
+ def idctn (
33
+ x : onp .CanArrayND [np .float16 , _ShapeT ],
34
+ type : DCTType = 2 ,
35
+ s : onp .ToInt | onp .ToIntND | None = None ,
36
+ axes : AnyShape | None = None ,
37
+ norm : NormalizationMode | None = None ,
38
+ overwrite_x : op .CanBool = False ,
39
+ workers : onp .ToInt | None = None ,
40
+ * ,
41
+ orthogonalize : op .CanBool | None = None ,
42
+ ) -> onp .Array [_ShapeT , np .float32 ]: ...
43
+ @overload
44
+ def idctn (
45
+ x : onp .ToJustFloat64_ND ,
46
+ type : DCTType = 2 ,
47
+ s : onp .ToInt | onp .ToIntND | None = None ,
48
+ axes : AnyShape | None = None ,
49
+ norm : NormalizationMode | None = None ,
50
+ overwrite_x : op .CanBool = False ,
51
+ workers : onp .ToInt | None = None ,
52
+ * ,
53
+ orthogonalize : op .CanBool | None = None ,
54
+ ) -> onp .ArrayND [np .float64 ]: ...
55
+ @overload
56
+ def idctn (
57
+ x : onp .ToFloatND ,
58
+ type : DCTType = 2 ,
59
+ s : onp .ToInt | onp .ToIntND | None = None ,
60
+ axes : AnyShape | None = None ,
61
+ norm : NormalizationMode | None = None ,
62
+ overwrite_x : op .CanBool = False ,
63
+ workers : onp .ToInt | None = None ,
64
+ * ,
65
+ orthogonalize : op .CanBool | None = None ,
66
+ ) -> onp .ArrayND [npc .floating ]: ...
67
+
68
+ #
69
+ @overload
70
+ def idstn (
71
+ x : onp .CanArrayND [npc .integer , _ShapeT ],
72
+ type : DCTType = 2 ,
73
+ s : onp .ToInt | onp .ToIntND | None = None ,
74
+ axes : AnyShape | None = None ,
75
+ norm : NormalizationMode | None = None ,
76
+ overwrite_x : op .CanBool = False ,
77
+ workers : onp .ToInt | None = None ,
78
+ * ,
79
+ orthogonalize : op .CanBool | None = None ,
80
+ ) -> onp .Array [_ShapeT , np .float64 ]: ...
81
+ @overload
82
+ def idstn (
83
+ x : onp .CanArrayND [np .float16 , _ShapeT ],
84
+ type : DCTType = 2 ,
85
+ s : onp .ToInt | onp .ToIntND | None = None ,
86
+ axes : AnyShape | None = None ,
87
+ norm : NormalizationMode | None = None ,
88
+ overwrite_x : op .CanBool = False ,
89
+ workers : onp .ToInt | None = None ,
90
+ * ,
91
+ orthogonalize : op .CanBool | None = None ,
92
+ ) -> onp .Array [_ShapeT , np .float32 ]: ...
93
+ @overload
94
+ def idstn (
95
+ x : onp .CanArray [_ShapeT , _DTypeT ],
96
+ type : DCTType = 2 ,
97
+ s : onp .ToInt | onp .ToIntND | None = None ,
98
+ axes : AnyShape | None = None ,
99
+ norm : NormalizationMode | None = None ,
100
+ overwrite_x : op .CanBool = False ,
101
+ workers : onp .ToInt | None = None ,
102
+ * ,
103
+ orthogonalize : op .CanBool | None = None ,
104
+ ) -> np .ndarray [_ShapeT , _DTypeT ]: ...
105
+ @overload
106
+ def idstn (
107
+ x : onp .ToJustFloat64_ND ,
17
108
type : DCTType = 2 ,
18
109
s : onp .ToInt | onp .ToIntND | None = None ,
19
110
axes : AnyShape | None = None ,
@@ -22,9 +113,10 @@ def idctn(
22
113
workers : onp .ToInt | None = None ,
23
114
* ,
24
115
orthogonalize : op .CanBool | None = None ,
25
- ) -> _RealND : ...
116
+ ) -> onp .ArrayND [np .float64 ]: ...
117
+ @overload
26
118
def idstn (
27
- x : onp .ToComplexND ,
119
+ x : onp .ToFloatND ,
28
120
type : DCTType = 2 ,
29
121
s : onp .ToInt | onp .ToIntND | None = None ,
30
122
axes : AnyShape | None = None ,
@@ -33,4 +125,4 @@ def idstn(
33
125
workers : onp .ToInt | None = None ,
34
126
* ,
35
127
orthogonalize : op .CanBool | None = None ,
36
- ) -> _RealND : ...
128
+ ) -> onp . ArrayND [ npc . floating ] : ...
0 commit comments