Skip to content

Commit 48daa90

Browse files
authored
ndimage: improved fourier_* annotations (#740)
2 parents e396b6f + 1d9b2ac commit 48daa90

File tree

2 files changed

+277
-63
lines changed

2 files changed

+277
-63
lines changed

scipy-stubs/ndimage/_fourier.pyi

Lines changed: 184 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,202 @@
1-
from typing import TypeVar, overload
1+
from typing import Any, SupportsFloat as CanFloat, SupportsIndex as CanIndex, TypeAlias, TypeVar, overload
22

33
import numpy as np
44
import optype.numpy as onp
55

66
__all__ = ["fourier_ellipsoid", "fourier_gaussian", "fourier_shift", "fourier_uniform"]
77

8-
_FloatArrayOutT = TypeVar("_FloatArrayOutT", bound=onp.ArrayND[np.float64 | np.float32])
9-
_ComplexArrayOutT = TypeVar("_ComplexArrayOutT", bound=onp.ArrayND[np.complex128 | np.float64 | np.complex64 | np.float32])
8+
###
109

11-
#
12-
@overload
10+
_ScalarComplex: TypeAlias = np.complex64 | np.complex128
11+
_OutputScalarComplexT = TypeVar("_OutputScalarComplexT", bound=_ScalarComplex)
12+
_OutputArrayComplexT = TypeVar("_OutputArrayComplexT", bound=onp.ArrayND[_ScalarComplex])
13+
14+
_Scalar: TypeAlias = np.float32 | np.float64 | _ScalarComplex
15+
_OutputScalarT = TypeVar("_OutputScalarT", bound=_Scalar)
16+
_OutputArrayT = TypeVar("_OutputArrayT", bound=onp.ArrayND[_Scalar])
17+
18+
_Sigma: TypeAlias = CanFloat | onp.ToFloat1D
19+
_AsF32: TypeAlias = np.float16 | np.float32
20+
_InputF64: TypeAlias = onp.ToJustFloat64_ND | onp.ToIntND
21+
_InputC128: TypeAlias = onp.ToJustComplex128_ND # pro forma
22+
# these *should* be equivalent to `onp.ArrayND[Never, {}]`, but both mypy and pyright currently have bugs with `Never` unions
23+
_InputF32: TypeAlias = onp.CanArrayND[_AsF32] | onp.SequenceND[onp.CanArray[Any, np.dtype[_AsF32]]]
24+
_InputC64: TypeAlias = onp.CanArrayND[np.complex64] | onp.SequenceND[onp.CanArray[Any, np.dtype[np.complex64]]]
25+
26+
###
27+
# NOTE: The gaussian, uniform, and ellipsoid function signatures are equivalent (except for the 2nd *name*): Keep them in sync!
28+
# NOTE: The [overload-overlap] mypy errors are false positives (probably a union/join thing).
29+
30+
# undocumented
31+
@overload # output: <T: ndarray>
32+
def _get_output_fourier(output: _OutputArrayT, input: onp.ToComplex128_ND) -> _OutputArrayT: ...
33+
@overload # output: <T: scalar type>
34+
def _get_output_fourier(output: type[_OutputScalarT], input: onp.ToComplex128_ND) -> onp.ArrayND[_OutputScalarT]: ...
35+
@overload # +float32
36+
def _get_output_fourier(output: None, input: _InputF32) -> onp.ArrayND[np.float32]: ... # type: ignore[overload-overlap]
37+
@overload # +float64
38+
def _get_output_fourier(output: None, input: _InputF64) -> onp.ArrayND[np.float64]: ...
39+
@overload # ~complex64
40+
def _get_output_fourier(output: None, input: _InputC64) -> onp.ArrayND[np.complex64]: ... # type: ignore[overload-overlap]
41+
@overload # ~complex128
42+
def _get_output_fourier(output: None, input: _InputC128) -> onp.ArrayND[np.complex128]: ...
43+
@overload # fallback
44+
def _get_output_fourier(output: None, input: onp.ToComplex128_ND) -> onp.ArrayND[_Scalar]: ...
45+
46+
# undocumented
47+
@overload # output: complex64 array or scalar-type
48+
def _get_output_fourier_complex( # type: ignore[overload-overlap]
49+
output: onp.ArrayND[np.complex64] | type[np.complex64], input: onp.ToComplex128_ND
50+
) -> onp.ArrayND[np.complex64]: ...
51+
@overload # output: complex128 array or scalar-type
52+
def _get_output_fourier_complex(
53+
output: onp.ArrayND[np.complex128] | type[np.complex128], input: onp.ToComplex128_ND
54+
) -> onp.ArrayND[np.complex128]: ...
55+
@overload # ~complex64
56+
def _get_output_fourier_complex(output: None, input: _InputC64) -> onp.ArrayND[np.complex64]: ... # type: ignore[overload-overlap]
57+
@overload # ~complex128 | +floating
58+
def _get_output_fourier_complex(output: None, input: _InputC128 | onp.ToFloat64_ND) -> onp.ArrayND[np.complex128]: ...
59+
@overload # fallback
60+
def _get_output_fourier_complex(output: None, input: onp.ToComplex128_ND) -> onp.ArrayND[_ScalarComplex]: ...
61+
62+
# NOTE: Keep in sync with `fourier_uniform` and `fourier_ellipsoid` (but note the different 2nd parameter names)
63+
@overload # output: <T: ndarray> (positional)
1364
def fourier_gaussian(
14-
input: _FloatArrayOutT | onp.ToFloat | onp.ToFloatND,
15-
sigma: onp.ToFloat | onp.ToFloatND,
16-
n: onp.ToInt = -1,
17-
axis: onp.ToInt = -1,
18-
output: _FloatArrayOutT | None = None,
19-
) -> _FloatArrayOutT: ...
20-
@overload
65+
input: onp.ToComplex128_ND, sigma: _Sigma, n: CanIndex, axis: int, output: _OutputArrayT
66+
) -> _OutputArrayT: ...
67+
@overload # output: <T: ndarray> (keyword)
2168
def fourier_gaussian(
22-
input: _ComplexArrayOutT | onp.ToComplex | onp.ToComplexND,
23-
sigma: onp.ToFloat | onp.ToFloatND,
24-
n: onp.ToInt = -1,
25-
axis: onp.ToInt = -1,
26-
output: _ComplexArrayOutT | None = None,
27-
) -> _ComplexArrayOutT: ...
69+
input: onp.ToComplex128_ND, sigma: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: _OutputArrayT
70+
) -> _OutputArrayT: ...
71+
@overload # output: <T: scalar type> (positional)
72+
def fourier_gaussian(
73+
input: onp.ToComplex128_ND, sigma: _Sigma, n: CanIndex, axis: int, output: type[_OutputScalarT]
74+
) -> onp.ArrayND[_OutputScalarT]: ...
75+
@overload # output: <T: scalar type> (keyword)
76+
def fourier_gaussian(
77+
input: onp.ToComplex128_ND, sigma: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: type[_OutputScalarT]
78+
) -> onp.ArrayND[_OutputScalarT]: ...
79+
@overload # +float32
80+
def fourier_gaussian( # type: ignore[overload-overlap]
81+
input: _InputF32, sigma: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
82+
) -> onp.ArrayND[np.float32]: ...
83+
@overload # +float64
84+
def fourier_gaussian(
85+
input: _InputF64, sigma: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
86+
) -> onp.ArrayND[np.float64]: ...
87+
@overload # ~complex64
88+
def fourier_gaussian( # type: ignore[overload-overlap]
89+
input: _InputC64, sigma: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
90+
) -> onp.ArrayND[np.complex64]: ...
91+
@overload # ~complex128
92+
def fourier_gaussian(
93+
input: _InputC128, sigma: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
94+
) -> onp.ArrayND[np.complex128]: ...
95+
@overload # fallback
96+
def fourier_gaussian(
97+
input: onp.ToComplex128_ND, sigma: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
98+
) -> onp.ArrayND[_Scalar]: ...
2899

29-
#
30-
@overload
100+
# NOTE: Keep in sync with `fourier_ellipsoid` and `fourier_gaussian` (but note the different 2nd parameter name)
101+
@overload # output: <T: ndarray> (positional)
102+
def fourier_uniform(input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex, axis: int, output: _OutputArrayT) -> _OutputArrayT: ...
103+
@overload # output: <T: ndarray> (keyword)
104+
def fourier_uniform(
105+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: _OutputArrayT
106+
) -> _OutputArrayT: ...
107+
@overload # output: <T: scalar type> (positional)
31108
def fourier_uniform(
32-
input: _FloatArrayOutT | onp.ToFloat | onp.ToFloatND,
33-
size: onp.ToFloat | onp.ToFloatND,
34-
n: onp.ToInt = -1,
35-
axis: onp.ToInt = -1,
36-
output: _FloatArrayOutT | None = None,
37-
) -> _FloatArrayOutT: ...
38-
@overload
109+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex, axis: int, output: type[_OutputScalarT]
110+
) -> onp.ArrayND[_OutputScalarT]: ...
111+
@overload # output: <T: scalar type> (keyword)
39112
def fourier_uniform(
40-
input: _ComplexArrayOutT | onp.ToComplex | onp.ToComplexND,
41-
size: onp.ToFloat | onp.ToFloatND,
42-
n: onp.ToInt = -1,
43-
axis: onp.ToInt = -1,
44-
output: _ComplexArrayOutT | None = None,
45-
) -> _ComplexArrayOutT: ...
113+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: type[_OutputScalarT]
114+
) -> onp.ArrayND[_OutputScalarT]: ...
115+
@overload # +float32
116+
def fourier_uniform( # type: ignore[overload-overlap]
117+
input: _InputF32, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
118+
) -> onp.ArrayND[np.float32]: ...
119+
@overload # +float64
120+
def fourier_uniform(
121+
input: _InputF64, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
122+
) -> onp.ArrayND[np.float64]: ...
123+
@overload # ~complex64
124+
def fourier_uniform( # type: ignore[overload-overlap]
125+
input: _InputC64, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
126+
) -> onp.ArrayND[np.complex64]: ...
127+
@overload # ~complex128
128+
def fourier_uniform(
129+
input: _InputC128, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
130+
) -> onp.ArrayND[np.complex128]: ...
131+
@overload # fallback
132+
def fourier_uniform(
133+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
134+
) -> onp.ArrayND[_Scalar]: ...
46135

47-
#
48-
@overload
136+
# NOTE: Keep in sync with `fourier_uniform` and `fourier_gaussian` (but note the different 2nd parameter name)
137+
@overload # output: <T: ndarray> (positional)
138+
def fourier_ellipsoid(
139+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex, axis: int, output: _OutputArrayT
140+
) -> _OutputArrayT: ...
141+
@overload # output: <T: ndarray> (keyword)
142+
def fourier_ellipsoid(
143+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: _OutputArrayT
144+
) -> _OutputArrayT: ...
145+
@overload # output: <T: scalar type> (positional)
146+
def fourier_ellipsoid(
147+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex, axis: int, output: type[_OutputScalarT]
148+
) -> onp.ArrayND[_OutputScalarT]: ...
149+
@overload # output: <T: scalar type> (keyword)
49150
def fourier_ellipsoid(
50-
input: _FloatArrayOutT | onp.ToFloat | onp.ToFloatND,
51-
size: onp.ToFloat | onp.ToFloatND,
52-
n: onp.ToInt = -1,
53-
axis: onp.ToInt = -1,
54-
output: _FloatArrayOutT | None = None,
55-
) -> _FloatArrayOutT: ...
56-
@overload
151+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: type[_OutputScalarT]
152+
) -> onp.ArrayND[_OutputScalarT]: ...
153+
@overload # +float32
154+
def fourier_ellipsoid( # type: ignore[overload-overlap]
155+
input: _InputF32, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
156+
) -> onp.ArrayND[np.float32]: ...
157+
@overload # +float64
57158
def fourier_ellipsoid(
58-
input: _ComplexArrayOutT | onp.ToComplex | onp.ToComplexND,
59-
size: onp.ToFloat | onp.ToFloatND,
60-
n: onp.ToInt = -1,
61-
axis: onp.ToInt = -1,
62-
output: _ComplexArrayOutT | None = None,
63-
) -> _ComplexArrayOutT: ...
159+
input: _InputF64, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
160+
) -> onp.ArrayND[np.float64]: ...
161+
@overload # ~complex64
162+
def fourier_ellipsoid( # type: ignore[overload-overlap]
163+
input: _InputC64, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
164+
) -> onp.ArrayND[np.complex64]: ...
165+
@overload # ~complex128
166+
def fourier_ellipsoid(
167+
input: _InputC128, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
168+
) -> onp.ArrayND[np.complex128]: ...
169+
@overload # fallback
170+
def fourier_ellipsoid(
171+
input: onp.ToComplex128_ND, size: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
172+
) -> onp.ArrayND[_Scalar]: ...
64173

65-
#
66-
@overload
174+
# NOTE: Unlike the other three functions, this always returns complex output
175+
@overload # output: <T: ndarray> (positional)
176+
def fourier_shift(
177+
input: onp.ToComplex128_ND, shift: _Sigma, n: CanIndex, axis: int, output: _OutputArrayComplexT
178+
) -> _OutputArrayComplexT: ...
179+
@overload # output: <T: ndarray> (keyword)
180+
def fourier_shift(
181+
input: onp.ToComplex128_ND, shift: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: _OutputArrayComplexT
182+
) -> _OutputArrayComplexT: ...
183+
@overload # output: <T: scalar type> (positional)
184+
def fourier_shift(
185+
input: onp.ToComplex128_ND, shift: _Sigma, n: CanIndex, axis: int, output: type[_OutputScalarComplexT]
186+
) -> onp.ArrayND[_OutputScalarComplexT]: ...
187+
@overload # output: <T: scalar type> (keyword)
188+
def fourier_shift(
189+
input: onp.ToComplex128_ND, shift: _Sigma, n: CanIndex = -1, axis: int = -1, *, output: type[_OutputScalarComplexT]
190+
) -> onp.ArrayND[_OutputScalarComplexT]: ...
191+
@overload # ~complex64
192+
def fourier_shift( # type: ignore[overload-overlap]
193+
input: _InputC64, shift: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
194+
) -> onp.ArrayND[np.complex64]: ...
195+
@overload # ~complex128 | +floating
67196
def fourier_shift(
68-
input: _FloatArrayOutT | onp.ToFloat | onp.ToFloatND,
69-
shift: onp.ToFloat | onp.ToFloatND,
70-
n: onp.ToInt = -1,
71-
axis: onp.ToInt = -1,
72-
output: _FloatArrayOutT | None = None,
73-
) -> _FloatArrayOutT: ...
74-
@overload
197+
input: _InputC128 | onp.ToFloat64_ND, shift: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
198+
) -> onp.ArrayND[np.complex128]: ...
199+
@overload # fallback
75200
def fourier_shift(
76-
input: _ComplexArrayOutT | onp.ToComplex | onp.ToComplexND,
77-
shift: onp.ToFloat | onp.ToFloatND,
78-
n: onp.ToInt = -1,
79-
axis: onp.ToInt = -1,
80-
output: _ComplexArrayOutT | None = None,
81-
) -> _ComplexArrayOutT: ...
201+
input: onp.ToComplex128_ND, shift: _Sigma, n: CanIndex = -1, axis: int = -1, output: None = None
202+
) -> onp.ArrayND[_ScalarComplex]: ...

tests/ndimage/test__fourier.pyi

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# type-tests for `ndimage/_fourier.pyi`
2+
3+
from typing import TypeAlias, assert_type
4+
5+
import numpy as np
6+
import numpy.typing as npt
7+
import optype.numpy as onp
8+
9+
from scipy.ndimage import fourier_gaussian, fourier_shift
10+
11+
i8_nd: npt.NDArray[np.int8]
12+
f16_nd: npt.NDArray[np.float16]
13+
f32_nd: npt.NDArray[np.float32]
14+
f64_nd: npt.NDArray[np.float64]
15+
c64_nd: npt.NDArray[np.complex64]
16+
c128_nd: npt.NDArray[np.complex128]
17+
18+
int_2d: list[list[int]]
19+
float_2d: list[list[float]]
20+
complex_2d: list[list[complex]]
21+
22+
_OutputArray: TypeAlias = onp.Array2D[np.complex64]
23+
_OutputArrayND: TypeAlias = onp.ArrayND[np.complex64]
24+
output_array: _OutputArray
25+
output_sctype: type[np.complex64]
26+
27+
###
28+
# `fourier_gaussian` (also `fourier_ellipsoid` and `fourier_uniform`)
29+
# NOTE: `fourier_uniform` and `fourier_ellipsoid` have the same signature, so no need to also test those.
30+
31+
assert_type(fourier_gaussian(i8_nd, 4), onp.ArrayND[np.float64])
32+
assert_type(fourier_gaussian(f16_nd, 4), onp.ArrayND[np.float32])
33+
assert_type(fourier_gaussian(f32_nd, 4), onp.ArrayND[np.float32])
34+
assert_type(fourier_gaussian(f64_nd, 4), onp.ArrayND[np.float64])
35+
assert_type(fourier_gaussian(c64_nd, 4), onp.ArrayND[np.complex64])
36+
assert_type(fourier_gaussian(c128_nd, 4), onp.ArrayND[np.complex128])
37+
assert_type(fourier_gaussian(int_2d, 4), onp.ArrayND[np.float64])
38+
assert_type(fourier_gaussian(float_2d, 4), onp.ArrayND[np.float64])
39+
assert_type(fourier_gaussian(complex_2d, 4), onp.ArrayND[np.complex128])
40+
41+
assert_type(fourier_gaussian(i8_nd, 4, output=output_array), _OutputArray)
42+
assert_type(fourier_gaussian(f16_nd, 4, output=output_array), _OutputArray)
43+
assert_type(fourier_gaussian(f32_nd, 4, output=output_array), _OutputArray)
44+
assert_type(fourier_gaussian(f64_nd, 4, output=output_array), _OutputArray)
45+
assert_type(fourier_gaussian(c64_nd, 4, output=output_array), _OutputArray)
46+
assert_type(fourier_gaussian(c128_nd, 4, output=output_array), _OutputArray)
47+
assert_type(fourier_gaussian(int_2d, 4, output=output_array), _OutputArray)
48+
assert_type(fourier_gaussian(float_2d, 4, output=output_array), _OutputArray)
49+
assert_type(fourier_gaussian(complex_2d, 4, output=output_array), _OutputArray)
50+
51+
assert_type(fourier_gaussian(i8_nd, 4, output=output_sctype), _OutputArrayND)
52+
assert_type(fourier_gaussian(f16_nd, 4, output=output_sctype), _OutputArrayND)
53+
assert_type(fourier_gaussian(f32_nd, 4, output=output_sctype), _OutputArrayND)
54+
assert_type(fourier_gaussian(f64_nd, 4, output=output_sctype), _OutputArrayND)
55+
assert_type(fourier_gaussian(c64_nd, 4, output=output_sctype), _OutputArrayND)
56+
assert_type(fourier_gaussian(c128_nd, 4, output=output_sctype), _OutputArrayND)
57+
assert_type(fourier_gaussian(int_2d, 4, output=output_sctype), _OutputArrayND)
58+
assert_type(fourier_gaussian(float_2d, 4, output=output_sctype), _OutputArrayND)
59+
assert_type(fourier_gaussian(complex_2d, 4, output=output_sctype), _OutputArrayND)
60+
61+
###
62+
# `fourier_shift`
63+
# NOTE: Unlike the other three functions, this always returns complex output.
64+
65+
assert_type(fourier_shift(i8_nd, 4), onp.ArrayND[np.complex128])
66+
assert_type(fourier_shift(f16_nd, 4), onp.ArrayND[np.complex128])
67+
assert_type(fourier_shift(f32_nd, 4), onp.ArrayND[np.complex128])
68+
assert_type(fourier_shift(f64_nd, 4), onp.ArrayND[np.complex128])
69+
assert_type(fourier_shift(c64_nd, 4), onp.ArrayND[np.complex64])
70+
assert_type(fourier_shift(c128_nd, 4), onp.ArrayND[np.complex128])
71+
assert_type(fourier_shift(int_2d, 4), onp.ArrayND[np.complex128])
72+
assert_type(fourier_shift(float_2d, 4), onp.ArrayND[np.complex128])
73+
assert_type(fourier_shift(complex_2d, 4), onp.ArrayND[np.complex128])
74+
75+
assert_type(fourier_shift(i8_nd, 4, output=output_array), _OutputArray)
76+
assert_type(fourier_shift(f16_nd, 4, output=output_array), _OutputArray)
77+
assert_type(fourier_shift(f32_nd, 4, output=output_array), _OutputArray)
78+
assert_type(fourier_shift(f64_nd, 4, output=output_array), _OutputArray)
79+
assert_type(fourier_shift(c64_nd, 4, output=output_array), _OutputArray)
80+
assert_type(fourier_shift(c128_nd, 4, output=output_array), _OutputArray)
81+
assert_type(fourier_shift(int_2d, 4, output=output_array), _OutputArray)
82+
assert_type(fourier_shift(float_2d, 4, output=output_array), _OutputArray)
83+
assert_type(fourier_shift(complex_2d, 4, output=output_array), _OutputArray)
84+
85+
assert_type(fourier_shift(i8_nd, 4, output=output_sctype), _OutputArrayND)
86+
assert_type(fourier_shift(f16_nd, 4, output=output_sctype), _OutputArrayND)
87+
assert_type(fourier_shift(f32_nd, 4, output=output_sctype), _OutputArrayND)
88+
assert_type(fourier_shift(f64_nd, 4, output=output_sctype), _OutputArrayND)
89+
assert_type(fourier_shift(c64_nd, 4, output=output_sctype), _OutputArrayND)
90+
assert_type(fourier_shift(c128_nd, 4, output=output_sctype), _OutputArrayND)
91+
assert_type(fourier_shift(int_2d, 4, output=output_sctype), _OutputArrayND)
92+
assert_type(fourier_shift(float_2d, 4, output=output_sctype), _OutputArrayND)
93+
assert_type(fourier_shift(complex_2d, 4, output=output_sctype), _OutputArrayND)

0 commit comments

Comments
 (0)