Skip to content

Commit 12e1b1b

Browse files
authored
fft: sync (private) signatures of _realtransforms and _realtransforms_backend (#684)
2 parents f98448b + 8190dfb commit 12e1b1b

File tree

1 file changed

+98
-6
lines changed

1 file changed

+98
-6
lines changed
Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,110 @@
1-
from typing import TypeAlias
1+
from typing import TypeVar, overload
22

33
import numpy as np
44
import optype as op
55
import optype.numpy as onp
6+
import optype.numpy.compat as npc
67

78
from ._realtransforms import dct, dctn, dst, dstn, idct, idst
89
from scipy._typing import AnyShape, DCTType, NormalizationMode
910

1011
__all__ = ["dct", "dctn", "dst", "dstn", "idct", "idctn", "idst", "idstn"]
1112

12-
_RealND: TypeAlias = onp.ArrayND[np.float32 | np.float64 | np.longdouble]
13+
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
14+
_DTypeT = TypeVar("_DTypeT", bound=np.dtype[np.float32 | np.float64 | np.longdouble | npc.complexfloating])
1315

1416
# NOTE: Unlike the ones in `scipy.fft._realtransforms`, `orthogonalize` is keyword-only here.
17+
18+
#
19+
@overload
1520
def idctn(
16-
x: onp.ToComplexND,
21+
x: onp.CanArrayND[npc.integer, _ShapeT],
22+
type: DCTType = 2,
23+
s: onp.ToInt | onp.ToIntND | None = None,
24+
axes: AnyShape | None = None,
25+
norm: NormalizationMode | None = None,
26+
overwrite_x: op.CanBool = False,
27+
workers: onp.ToInt | None = None,
28+
*,
29+
orthogonalize: op.CanBool | None = None,
30+
) -> onp.Array[_ShapeT, np.float64]: ...
31+
@overload
32+
def idctn(
33+
x: onp.CanArrayND[np.float16, _ShapeT],
34+
type: DCTType = 2,
35+
s: onp.ToInt | onp.ToIntND | None = None,
36+
axes: AnyShape | None = None,
37+
norm: NormalizationMode | None = None,
38+
overwrite_x: op.CanBool = False,
39+
workers: onp.ToInt | None = None,
40+
*,
41+
orthogonalize: op.CanBool | None = None,
42+
) -> onp.Array[_ShapeT, np.float32]: ...
43+
@overload
44+
def idctn(
45+
x: onp.ToJustFloat64_ND,
46+
type: DCTType = 2,
47+
s: onp.ToInt | onp.ToIntND | None = None,
48+
axes: AnyShape | None = None,
49+
norm: NormalizationMode | None = None,
50+
overwrite_x: op.CanBool = False,
51+
workers: onp.ToInt | None = None,
52+
*,
53+
orthogonalize: op.CanBool | None = None,
54+
) -> onp.ArrayND[np.float64]: ...
55+
@overload
56+
def idctn(
57+
x: onp.ToFloatND,
58+
type: DCTType = 2,
59+
s: onp.ToInt | onp.ToIntND | None = None,
60+
axes: AnyShape | None = None,
61+
norm: NormalizationMode | None = None,
62+
overwrite_x: op.CanBool = False,
63+
workers: onp.ToInt | None = None,
64+
*,
65+
orthogonalize: op.CanBool | None = None,
66+
) -> onp.ArrayND[npc.floating]: ...
67+
68+
#
69+
@overload
70+
def idstn(
71+
x: onp.CanArrayND[npc.integer, _ShapeT],
72+
type: DCTType = 2,
73+
s: onp.ToInt | onp.ToIntND | None = None,
74+
axes: AnyShape | None = None,
75+
norm: NormalizationMode | None = None,
76+
overwrite_x: op.CanBool = False,
77+
workers: onp.ToInt | None = None,
78+
*,
79+
orthogonalize: op.CanBool | None = None,
80+
) -> onp.Array[_ShapeT, np.float64]: ...
81+
@overload
82+
def idstn(
83+
x: onp.CanArrayND[np.float16, _ShapeT],
84+
type: DCTType = 2,
85+
s: onp.ToInt | onp.ToIntND | None = None,
86+
axes: AnyShape | None = None,
87+
norm: NormalizationMode | None = None,
88+
overwrite_x: op.CanBool = False,
89+
workers: onp.ToInt | None = None,
90+
*,
91+
orthogonalize: op.CanBool | None = None,
92+
) -> onp.Array[_ShapeT, np.float32]: ...
93+
@overload
94+
def idstn(
95+
x: onp.CanArray[_ShapeT, _DTypeT],
96+
type: DCTType = 2,
97+
s: onp.ToInt | onp.ToIntND | None = None,
98+
axes: AnyShape | None = None,
99+
norm: NormalizationMode | None = None,
100+
overwrite_x: op.CanBool = False,
101+
workers: onp.ToInt | None = None,
102+
*,
103+
orthogonalize: op.CanBool | None = None,
104+
) -> np.ndarray[_ShapeT, _DTypeT]: ...
105+
@overload
106+
def idstn(
107+
x: onp.ToJustFloat64_ND,
17108
type: DCTType = 2,
18109
s: onp.ToInt | onp.ToIntND | None = None,
19110
axes: AnyShape | None = None,
@@ -22,9 +113,10 @@ def idctn(
22113
workers: onp.ToInt | None = None,
23114
*,
24115
orthogonalize: op.CanBool | None = None,
25-
) -> _RealND: ...
116+
) -> onp.ArrayND[np.float64]: ...
117+
@overload
26118
def idstn(
27-
x: onp.ToComplexND,
119+
x: onp.ToFloatND,
28120
type: DCTType = 2,
29121
s: onp.ToInt | onp.ToIntND | None = None,
30122
axes: AnyShape | None = None,
@@ -33,4 +125,4 @@ def idstn(
33125
workers: onp.ToInt | None = None,
34126
*,
35127
orthogonalize: op.CanBool | None = None,
36-
) -> _RealND: ...
128+
) -> onp.ArrayND[npc.floating]: ...

0 commit comments

Comments
 (0)