Skip to content

Commit caaeb1f

Browse files
authored
signal: improve resample_poly (#829)
2 parents 0e725d0 + 279afc3 commit caaeb1f

File tree

2 files changed

+119
-32
lines changed

2 files changed

+119
-32
lines changed

scipy-stubs/signal/_signaltools.pyi

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# mypy: disable-error-code=overload-overlap
55

6-
from collections.abc import Callable
6+
from collections.abc import Callable, Sequence
77
from typing import Any, Literal as L, TypeAlias, TypeVar, TypedDict, overload, type_check_only
88

99
import numpy as np
@@ -57,7 +57,8 @@ __all__ = [
5757
_T = TypeVar("_T")
5858
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
5959
_NumericT = TypeVar("_NumericT", bound=npc.number | np.bool_)
60-
_InexactStandardT = TypeVar("_InexactStandardT", bound=np.float32 | np.float64 | npc.floating80 | npc.complexfloating)
60+
_InexactT2 = TypeVar("_InexactT2", bound=np.float32 | np.float64 | np.complex64 | np.complex128)
61+
_InexactT3 = TypeVar("_InexactT3", bound=np.float32 | np.float64 | npc.floating80 | npc.complexfloating)
6162
_CoFloat64T = TypeVar("_CoFloat64T", bound=np.float64 | np.float32 | npc.integer)
6263
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
6364

@@ -334,11 +335,11 @@ def fftconvolve(
334335
) -> onp.ArrayND[np.float32, _AnyShapeT]: ...
335336
@overload # generic dtype, generic, shape
336337
def fftconvolve(
337-
in1: onp.ArrayND[_InexactStandardT, _AnyShapeT],
338-
in2: onp.ArrayND[_InexactStandardT, _AnyShapeT],
338+
in1: onp.ArrayND[_InexactT3, _AnyShapeT],
339+
in2: onp.ArrayND[_InexactT3, _AnyShapeT],
339340
mode: onp.ConvolveMode = "full",
340341
axes: None = None,
341-
) -> onp.ArrayND[_InexactStandardT, _AnyShapeT]: ...
342+
) -> onp.ArrayND[_InexactT3, _AnyShapeT]: ...
342343
@overload # ~float64, +float64
343344
def fftconvolve(
344345
in1: onp.ToJustFloat64_ND, in2: onp.ToFloat64_ND, mode: onp.ConvolveMode = "full", axes: AnyShape | None = None
@@ -370,11 +371,11 @@ def oaconvolve(
370371
) -> onp.ArrayND[np.float32, _AnyShapeT]: ...
371372
@overload # generic dtype, generic, shape
372373
def oaconvolve(
373-
in1: onp.ArrayND[_InexactStandardT, _AnyShapeT],
374-
in2: onp.ArrayND[_InexactStandardT, _AnyShapeT],
374+
in1: onp.ArrayND[_InexactT3, _AnyShapeT],
375+
in2: onp.ArrayND[_InexactT3, _AnyShapeT],
375376
mode: onp.ConvolveMode = "full",
376377
axes: None = None,
377-
) -> onp.ArrayND[_InexactStandardT, _AnyShapeT]: ...
378+
) -> onp.ArrayND[_InexactT3, _AnyShapeT]: ...
378379
@overload # ~float64, +float64
379380
def oaconvolve(
380381
in1: onp.ToJustFloat64_ND, in2: onp.ToFloat64_ND, mode: onp.ConvolveMode = "full", axes: AnyShape | None = None
@@ -1025,22 +1026,22 @@ def invresz(
10251026
# the (hypothetical) `AnyOf[np.float64, np.complex128]` gradual type.
10261027
@overload # known dtype, known shape, t=None (default)
10271028
def resample(
1028-
x: nptc.CanArray[_ShapeT, np.dtype[_InexactStandardT]],
1029+
x: nptc.CanArray[_ShapeT, np.dtype[_InexactT3]],
10291030
num: int,
10301031
t: None = None,
10311032
axis: int = 0,
10321033
window: _ToResampleWindow[_AnyInexact64T] | None = None,
10331034
domain: _Domain = "time",
1034-
) -> onp.ArrayND[_InexactStandardT, _ShapeT]: ...
1035+
) -> onp.ArrayND[_InexactT3, _ShapeT]: ...
10351036
@overload # known dtype, known shape, t=<given>
10361037
def resample(
1037-
x: nptc.CanArray[_ShapeT, np.dtype[_InexactStandardT]],
1038+
x: nptc.CanArray[_ShapeT, np.dtype[_InexactT3]],
10381039
num: int,
10391040
t: onp.ToFloat1D,
10401041
axis: int = 0,
10411042
window: _ToResampleWindow[_AnyInexact64T] | None = None,
10421043
domain: _Domain = "time",
1043-
) -> tuple[onp.ArrayND[_InexactStandardT, _ShapeT], onp.Array1D[np.float64]]: ...
1044+
) -> tuple[onp.ArrayND[_InexactT3, _ShapeT], onp.Array1D[np.float64]]: ...
10441045
@overload # +integer, known shape, t=None (default)
10451046
def resample(
10461047
x: nptc.CanArray[_ShapeT, np.dtype[npc.integer | np.bool_]],
@@ -1068,7 +1069,7 @@ def resample(
10681069
window: _ToResampleWindow[np.float64] | None = None,
10691070
domain: _Domain = "time",
10701071
) -> onp.ArrayND[np.float32, _ShapeT]: ...
1071-
@overload # ~float16, unknown shape, t=<given>
1072+
@overload # ~float16, known shape, t=<given>
10721073
def resample(
10731074
x: nptc.CanArray[_ShapeT, np.dtype[np.float16]],
10741075
num: int,
@@ -1077,7 +1078,7 @@ def resample(
10771078
window: _ToResampleWindow[np.float64] | None = None,
10781079
domain: _Domain = "time",
10791080
) -> tuple[onp.ArrayND[np.float32, _ShapeT], onp.Array1D[np.float64]]: ...
1080-
@overload # ~float64 | +integer, unknown shape, t=None (default)
1081+
@overload # +float, unknown shape, t=None (default)
10811082
def resample(
10821083
x: onp.SequenceND[float],
10831084
num: int,
@@ -1086,7 +1087,7 @@ def resample(
10861087
window: _ToResampleWindow[np.float64] | None = None,
10871088
domain: _Domain = "time",
10881089
) -> onp.ArrayND[np.float64]: ...
1089-
@overload # ~float64 | +integer, unknown shape, t=<given>
1090+
@overload # +float, unknown shape, t=<given>
10901091
def resample(
10911092
x: onp.SequenceND[float],
10921093
num: int,
@@ -1095,7 +1096,7 @@ def resample(
10951096
window: _ToResampleWindow[np.float64] | None = None,
10961097
domain: _Domain = "time",
10971098
) -> tuple[onp.ArrayND[np.float64], onp.Array1D[np.float64]]: ...
1098-
@overload # ~complex128, unknown shape, t=None (default)
1099+
@overload # ~complex, unknown shape, t=None (default)
10991100
def resample(
11001101
x: onp.SequenceND[op.JustComplex | np.complex128],
11011102
num: int,
@@ -1104,7 +1105,7 @@ def resample(
11041105
window: _ToResampleWindow[np.complex128] | None = None,
11051106
domain: _Domain = "time",
11061107
) -> onp.ArrayND[np.complex128]: ...
1107-
@overload # ~complex128, unknown shape, t=<given>
1108+
@overload # ~complex, unknown shape, t=<given>
11081109
def resample(
11091110
x: onp.SequenceND[op.JustComplex | np.complex128],
11101111
num: int,
@@ -1132,27 +1133,87 @@ def resample(
11321133
domain: _Domain = "time",
11331134
) -> tuple[onp.ArrayND[Any, _WorkaroundForPyright], onp.Array1D[np.float64]]: ...
11341135

1135-
# TODO(jorenham): improve
1136-
@overload
1136+
# NOTE: This does not support the (useless) `up == down` case, which at runtime can return ANY dtype.
1137+
@overload # known dtype, known shape
11371138
def resample_poly(
1138-
x: onp.ToFloatND,
1139+
x: nptc.CanArray[_ShapeT, np.dtype[_InexactT2]],
11391140
up: int,
11401141
down: int,
11411142
axis: int = 0,
11421143
window: _ToWindow = ("kaiser", 5.0),
11431144
padtype: _PadType = "constant",
11441145
cval: float | None = None,
1145-
) -> onp.ArrayND[_F16_64]: ...
1146-
@overload
1146+
) -> onp.ArrayND[_InexactT2, _ShapeT]: ...
1147+
@overload # +integer, known shape
11471148
def resample_poly(
1148-
x: onp.ToComplexND,
1149+
x: nptc.CanArray[_ShapeT, np.dtype[npc.integer | np.bool_]],
11491150
up: int,
11501151
down: int,
11511152
axis: int = 0,
11521153
window: _ToWindow = ("kaiser", 5.0),
11531154
padtype: _PadType = "constant",
11541155
cval: float | None = None,
1155-
) -> onp.ArrayND[_C64_128 | _F16_64]: ...
1156+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
1157+
@overload # ~float16, known shape
1158+
def resample_poly(
1159+
x: nptc.CanArray[_ShapeT, np.dtype[np.float16]],
1160+
up: int,
1161+
down: int,
1162+
axis: int = 0,
1163+
window: _ToWindow = ("kaiser", 5.0),
1164+
padtype: _PadType = "constant",
1165+
cval: float | None = None,
1166+
) -> onp.ArrayND[np.float32, _ShapeT]: ...
1167+
@overload # +float, 1d
1168+
def resample_poly(
1169+
x: Sequence[float],
1170+
up: int,
1171+
down: int,
1172+
axis: int = 0,
1173+
window: _ToWindow = ("kaiser", 5.0),
1174+
padtype: _PadType = "constant",
1175+
cval: float | None = None,
1176+
) -> onp.Array1D[np.float64]: ...
1177+
@overload # +float, unknown shape
1178+
def resample_poly(
1179+
x: onp.SequenceND[float],
1180+
up: int,
1181+
down: int,
1182+
axis: int = 0,
1183+
window: _ToWindow = ("kaiser", 5.0),
1184+
padtype: _PadType = "constant",
1185+
cval: float | None = None,
1186+
) -> onp.ArrayND[np.float64]: ...
1187+
@overload # ~complex, 1d
1188+
def resample_poly(
1189+
x: Sequence[op.JustComplex | np.complex128],
1190+
up: int,
1191+
down: int,
1192+
axis: int = 0,
1193+
window: _ToWindow = ("kaiser", 5.0),
1194+
padtype: _PadType = "constant",
1195+
cval: float | None = None,
1196+
) -> onp.Array1D[np.complex128]: ...
1197+
@overload # ~complex, unknown shape
1198+
def resample_poly(
1199+
x: onp.SequenceND[op.JustComplex | np.complex128],
1200+
up: int,
1201+
down: int,
1202+
axis: int = 0,
1203+
window: _ToWindow = ("kaiser", 5.0),
1204+
padtype: _PadType = "constant",
1205+
cval: float | None = None,
1206+
) -> onp.ArrayND[np.complex128]: ...
1207+
@overload # unknown dtype, unknown shape
1208+
def resample_poly(
1209+
x: onp.ToComplex128_ND,
1210+
up: int,
1211+
down: int,
1212+
axis: int = 0,
1213+
window: _ToWindow = ("kaiser", 5.0),
1214+
padtype: _PadType = "constant",
1215+
cval: float | None = None,
1216+
) -> onp.ArrayND[Any, _WorkaroundForPyright]: ...
11561217

11571218
# TODO(jorenham): improve
11581219
@overload
@@ -1207,31 +1268,31 @@ def envelope(
12071268
) -> onp.ArrayND[np.float32]: ...
12081269
@overload
12091270
def envelope(
1210-
z: onp.Array1D[_InexactStandardT],
1271+
z: onp.Array1D[_InexactT3],
12111272
bp_in: tuple[int | None, int | None] = (1, None),
12121273
*,
12131274
n_out: int | None = None,
12141275
squared: bool = False,
12151276
residual: _ResidualKind | None = "lowpass",
12161277
axis: int = -1,
1217-
) -> onp.Array2D[_InexactStandardT]: ...
1278+
) -> onp.Array2D[_InexactT3]: ...
12181279
@overload
12191280
def envelope(
1220-
z: onp.Array2D[_InexactStandardT],
1281+
z: onp.Array2D[_InexactT3],
12211282
bp_in: tuple[int | None, int | None] = (1, None),
12221283
*,
12231284
n_out: int | None = None,
12241285
squared: bool = False,
12251286
residual: _ResidualKind | None = "lowpass",
12261287
axis: int = -1,
1227-
) -> onp.Array3D[_InexactStandardT]: ...
1288+
) -> onp.Array3D[_InexactT3]: ...
12281289
@overload
12291290
def envelope(
1230-
z: onp.ArrayND[_InexactStandardT],
1291+
z: onp.ArrayND[_InexactT3],
12311292
bp_in: tuple[int | None, int | None] = (1, None),
12321293
*,
12331294
n_out: int | None = None,
12341295
squared: bool = False,
12351296
residual: _ResidualKind | None = "lowpass",
12361297
axis: int = -1,
1237-
) -> onp.ArrayND[_InexactStandardT]: ...
1298+
) -> onp.ArrayND[_InexactT3]: ...

tests/signal/test_resample.pyi

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# type-tests for `resample` from `signal/_signaltools.pyi`
1+
# type-tests for `resample` and `resample_poly` from `signal/_signaltools.pyi`
22

33
from typing import assert_type
44

55
import numpy as np
66
import optype.numpy as onp
77

8-
from scipy.signal import resample
8+
from scipy.signal import resample, resample_poly
99

1010
num: int
1111

@@ -34,6 +34,7 @@ c160_1d: onp.Array1D[np.complex256]
3434
c160_2d: onp.Array2D[np.complex256]
3535

3636
###
37+
# resample
3738

3839
assert_type(resample(py_i_1d, num), onp.ArrayND[np.float64])
3940
assert_type(resample(py_f_1d, num), onp.ArrayND[np.float64])
@@ -56,3 +57,28 @@ assert_type(resample(f80_2d, num), onp.Array2D[np.float128])
5657
assert_type(resample(c64_2d, num), onp.Array2D[np.complex64])
5758
assert_type(resample(c128_2d, num), onp.Array2D[np.complex128])
5859
assert_type(resample(c160_2d, num), onp.Array2D[np.complex256])
60+
61+
###
62+
# resample_poly
63+
64+
assert_type(resample_poly(py_i_1d, num, num), onp.Array1D[np.float64])
65+
assert_type(resample_poly(py_f_1d, num, num), onp.Array1D[np.float64])
66+
assert_type(resample_poly(i8_1d, num, num), onp.Array1D[np.float64])
67+
assert_type(resample_poly(f16_1d, num, num), onp.Array1D[np.float32])
68+
assert_type(resample_poly(f32_1d, num, num), onp.Array1D[np.float32])
69+
assert_type(resample_poly(f64_1d, num, num), onp.Array1D[np.float64])
70+
resample_poly(f80_1d, num, num) # type: ignore[type-var] # pyright: ignore[reportArgumentType, reportCallIssue]
71+
assert_type(resample_poly(c64_1d, num, num), onp.Array1D[np.complex64])
72+
assert_type(resample_poly(c128_1d, num, num), onp.Array1D[np.complex128])
73+
resample_poly(c160_1d, num, num) # type: ignore[type-var] # pyright: ignore[reportArgumentType, reportCallIssue]
74+
75+
assert_type(resample_poly(py_i_2d, num, num), onp.ArrayND[np.float64])
76+
assert_type(resample_poly(py_f_2d, num, num), onp.ArrayND[np.float64])
77+
assert_type(resample_poly(i8_2d, num, num), onp.Array2D[np.float64])
78+
assert_type(resample_poly(f16_2d, num, num), onp.Array2D[np.float32])
79+
assert_type(resample_poly(f32_2d, num, num), onp.Array2D[np.float32])
80+
resample_poly(f80_2d, num, num) # type: ignore[type-var] # pyright: ignore[reportArgumentType, reportCallIssue]
81+
assert_type(resample_poly(f64_2d, num, num), onp.Array2D[np.float64])
82+
assert_type(resample_poly(c64_2d, num, num), onp.Array2D[np.complex64])
83+
assert_type(resample_poly(c128_2d, num, num), onp.Array2D[np.complex128])
84+
resample_poly(c160_2d, num, num) # type: ignore[type-var] # pyright: ignore[reportArgumentType, reportCallIssue]

0 commit comments

Comments
 (0)