Skip to content

Commit 5ed42d4

Browse files
authored
stats: improve zscore and gzscore (#816)
2 parents 1656364 + 5ee69f5 commit 5ed42d4

File tree

2 files changed

+167
-13
lines changed

2 files changed

+167
-13
lines changed

scipy-stubs/stats/_stats_py.pyi

Lines changed: 97 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ from typing_extensions import NamedTuple, TypeVar, deprecated
66

77
import numpy as np
88
import numpy.typing as npt
9+
import numpy_typing_compat as nptc
910
import optype as op
1011
import optype.numpy as onp
1112
import optype.numpy.compat as npc
@@ -89,6 +90,8 @@ __all__ = [
8990

9091
_SCT = TypeVar("_SCT", bound=np.generic)
9192

93+
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
94+
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
9295
_FloatT = TypeVar("_FloatT", bound=npc.floating, default=npc.floating)
9396
_RealT = TypeVar("_RealT", bound=_Real0D, default=_Real0D)
9497
_RealT_co = TypeVar("_RealT_co", bound=_Real0D, default=_Real0D, covariant=True)
@@ -684,17 +687,98 @@ def sem(
684687
a: onp.ToComplexND, axis: int | None = 0, ddof: int = 1, nan_policy: NanPolicy = "propagate", *, keepdims: bool = False
685688
) -> _FloatOrND: ...
686689

687-
# TODO(jorenham): improve
690+
# NOTE: keep in sync with `gzscore`
691+
@overload # +integer, known shape
692+
def zscore(
693+
a: nptc.CanArray[_ShapeT, np.dtype[npc.integer | np.bool_]],
694+
axis: int | None = 0,
695+
ddof: int = 0,
696+
nan_policy: NanPolicy = "propagate",
697+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
698+
@overload # known inexact dtype, known shape
699+
def zscore(
700+
a: nptc.CanArray[_ShapeT, np.dtype[_InexactT]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
701+
) -> onp.ArrayND[_InexactT, _ShapeT]: ...
702+
@overload # float 1d
703+
def zscore(
704+
a: Sequence[float], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
705+
) -> onp.Array1D[np.float64]: ...
706+
@overload # float 2d
707+
def zscore(
708+
a: Sequence[Sequence[float]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
709+
) -> onp.Array2D[np.float64]: ...
710+
@overload # float 3d
711+
def zscore(
712+
a: Sequence[Sequence[Sequence[float]]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
713+
) -> onp.Array3D[np.float64]: ...
714+
@overload # complex 1d
715+
def zscore(
716+
a: Sequence[op.JustComplex], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
717+
) -> onp.Array1D[np.complex128]: ...
718+
@overload # complex 2d
688719
def zscore(
720+
a: Sequence[Sequence[op.JustComplex]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
721+
) -> onp.Array2D[np.complex128]: ...
722+
@overload # complex 3d
723+
def zscore(
724+
a: Sequence[Sequence[Sequence[op.JustComplex]]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
725+
) -> onp.Array3D[np.complex128]: ...
726+
@overload # floating fallback
727+
def zscore( # the `tuple[int] | tuple[Any, ...]` return shape-type works around a bug in pyright's overlapping-overload detection on numpy<2.1
689728
a: onp.ToFloatND, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
690-
) -> onp.ArrayND[npc.floating]: ...
729+
) -> onp.ArrayND[npc.floating, tuple[int] | tuple[Any, ...]]: ...
730+
@overload # complex fallback
731+
def zscore(
732+
a: onp.ToJustComplexND, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
733+
) -> onp.ArrayND[npc.complexfloating]: ...
691734

692-
# TODO(jorenham): improve
735+
# NOTE: keep in sync with `zscore`
736+
@overload # +integer, known shape
737+
def gzscore(
738+
a: nptc.CanArray[_ShapeT, np.dtype[npc.integer | np.bool_]],
739+
*,
740+
axis: int | None = 0,
741+
ddof: int = 0,
742+
nan_policy: NanPolicy = "propagate",
743+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
744+
@overload # known inexact dtype, known shape
745+
def gzscore(
746+
a: nptc.CanArray[_ShapeT, np.dtype[_InexactT]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
747+
) -> onp.ArrayND[_InexactT, _ShapeT]: ...
748+
@overload # float 1d
749+
def gzscore(
750+
a: Sequence[float], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
751+
) -> onp.Array1D[np.float64]: ...
752+
@overload # float 2d
753+
def gzscore(
754+
a: Sequence[Sequence[float]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
755+
) -> onp.Array2D[np.float64]: ...
756+
@overload # float 3d
757+
def gzscore(
758+
a: Sequence[Sequence[Sequence[float]]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
759+
) -> onp.Array3D[np.float64]: ...
760+
@overload # complex 1d
693761
def gzscore(
762+
a: Sequence[op.JustComplex], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
763+
) -> onp.Array1D[np.complex128]: ...
764+
@overload # complex 2d
765+
def gzscore(
766+
a: Sequence[Sequence[op.JustComplex]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
767+
) -> onp.Array2D[np.complex128]: ...
768+
@overload # complex 3d
769+
def gzscore(
770+
a: Sequence[Sequence[Sequence[op.JustComplex]]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
771+
) -> onp.Array3D[np.complex128]: ...
772+
@overload # floating fallback
773+
def gzscore( # the `tuple[int] | tuple[Any, ...]` return shape-type works around a bug in pyright's overlapping-overload detection on numpy<2.1
694774
a: onp.ToFloatND, *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
695-
) -> onp.ArrayND[npc.floating]: ...
775+
) -> onp.ArrayND[npc.floating, tuple[int] | tuple[Any, ...]]: ...
776+
@overload # complex fallback
777+
def gzscore(
778+
a: onp.ToJustComplexND, *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
779+
) -> onp.ArrayND[npc.complexfloating]: ...
696780

697-
# TODO(jorenham): improve
781+
# TODO(jorenham): improve like zscore
698782
@overload # (real vector-like, real vector-like) -> floating vector
699783
def zmap(
700784
scores: onp.ToFloat1D, compare: onp.ToFloat1D, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
@@ -1568,7 +1652,7 @@ def lmoment(
15681652
sample: onp.ToFloatStrict2D,
15691653
order: _LMomentOrder,
15701654
*,
1571-
axis: L[0, 1, -1, -2] = 0,
1655+
axis: int = 0,
15721656
keepdims: onp.ToFalse = False,
15731657
sorted: op.CanBool = False,
15741658
standardize: op.CanBool = True,
@@ -1579,7 +1663,7 @@ def lmoment(
15791663
sample: onp.ToFloatStrict2D,
15801664
order: _LMomentOrder,
15811665
*,
1582-
axis: L[0, 1, -1, -2] | None = 0,
1666+
axis: int | None = 0,
15831667
keepdims: onp.ToTrue,
15841668
sorted: op.CanBool = False,
15851669
standardize: op.CanBool = True,
@@ -1590,7 +1674,7 @@ def lmoment(
15901674
sample: onp.ToFloatStrict2D,
15911675
order: _LMomentOrder1D | None = None,
15921676
*,
1593-
axis: L[0, 1, -1, -2] = 0,
1677+
axis: int = 0,
15941678
keepdims: onp.ToFalse = False,
15951679
sorted: op.CanBool = False,
15961680
standardize: op.CanBool = True,
@@ -1601,7 +1685,7 @@ def lmoment(
16011685
sample: onp.ToFloatStrict2D,
16021686
order: _LMomentOrder1D | None = None,
16031687
*,
1604-
axis: L[0, 1, -1, -2] | None = 0,
1688+
axis: int | None = 0,
16051689
keepdims: onp.ToTrue,
16061690
sorted: op.CanBool = False,
16071691
standardize: op.CanBool = True,
@@ -1612,7 +1696,7 @@ def lmoment(
16121696
sample: onp.ToFloatStrict3D,
16131697
order: _LMomentOrder,
16141698
*,
1615-
axis: L[0, 1, 2, -1, -2, -3] = 0,
1699+
axis: int = 0,
16161700
keepdims: onp.ToFalse = False,
16171701
sorted: op.CanBool = False,
16181702
standardize: op.CanBool = True,
@@ -1623,7 +1707,7 @@ def lmoment(
16231707
sample: onp.ToFloatStrict3D,
16241708
order: _LMomentOrder,
16251709
*,
1626-
axis: L[0, 1, 2, -1, -2, -3] | None = 0,
1710+
axis: int | None = 0,
16271711
keepdims: onp.ToTrue,
16281712
sorted: op.CanBool = False,
16291713
standardize: op.CanBool = True,
@@ -1634,7 +1718,7 @@ def lmoment(
16341718
sample: onp.ToFloatStrict3D,
16351719
order: _LMomentOrder1D | None = None,
16361720
*,
1637-
axis: L[0, 1, 2, -1, -2, -3] = 0,
1721+
axis: int = 0,
16381722
keepdims: onp.ToFalse = False,
16391723
sorted: op.CanBool = False,
16401724
standardize: op.CanBool = True,
@@ -1645,7 +1729,7 @@ def lmoment(
16451729
sample: onp.ToFloatStrict3D,
16461730
order: _LMomentOrder1D | None = None,
16471731
*,
1648-
axis: L[0, 1, 2, -1, -2, -3] | None = 0,
1732+
axis: int | None = 0,
16491733
keepdims: onp.ToTrue,
16501734
sorted: op.CanBool = False,
16511735
standardize: op.CanBool = True,

tests/stats/test_zscore.pyi

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,70 @@
1+
# type-tests for `zscore` from `stats/_stats_py.pyi`
2+
3+
from typing import assert_type
4+
5+
import numpy as np
6+
import optype.numpy as onp
7+
8+
from scipy.stats import zscore
9+
10+
py_i_1d: list[int]
11+
py_i_2d: list[list[int]]
12+
13+
py_f_1d: list[float]
14+
py_f_2d: list[list[float]]
15+
16+
bool_1d: onp.Array1D[np.bool_]
17+
bool_2d: onp.Array2D[np.bool_]
18+
19+
i16_1d: onp.Array1D[np.int16]
20+
i16_2d: onp.Array2D[np.int16]
21+
22+
f32_1d: onp.Array1D[np.float32]
23+
f32_2d: onp.Array2D[np.float32]
24+
25+
f64_1d: onp.Array1D[np.float64]
26+
f64_2d: onp.Array2D[np.float64]
27+
28+
c64_1d: onp.Array1D[np.complex64]
29+
c64_2d: onp.Array2D[np.complex64]
30+
31+
c128_1d: onp.Array1D[np.complex128]
32+
c128_2d: onp.Array2D[np.complex128]
33+
34+
###
35+
36+
assert_type(zscore(py_i_1d), onp.Array1D[np.float64])
37+
assert_type(zscore(py_f_1d), onp.Array1D[np.float64])
38+
assert_type(zscore(bool_1d), onp.Array1D[np.float64])
39+
assert_type(zscore(i16_1d), onp.Array1D[np.float64])
40+
assert_type(zscore(f32_1d), onp.Array1D[np.float32])
41+
assert_type(zscore(f64_1d), onp.Array1D[np.float64])
42+
assert_type(zscore(c64_1d), onp.Array1D[np.complex64])
43+
assert_type(zscore(c128_1d), onp.Array1D[np.complex128])
44+
45+
assert_type(zscore(py_i_2d), onp.Array2D[np.float64])
46+
assert_type(zscore(py_f_2d), onp.Array2D[np.float64])
47+
assert_type(zscore(bool_2d), onp.Array2D[np.float64])
48+
assert_type(zscore(i16_2d), onp.Array2D[np.float64])
49+
assert_type(zscore(f32_2d), onp.Array2D[np.float32])
50+
assert_type(zscore(f64_2d), onp.Array2D[np.float64])
51+
assert_type(zscore(c64_2d), onp.Array2D[np.complex64])
52+
assert_type(zscore(c128_2d), onp.Array2D[np.complex128])
53+
54+
assert_type(zscore(py_i_1d, axis=None), onp.Array1D[np.float64])
55+
assert_type(zscore(py_f_1d, axis=None), onp.Array1D[np.float64])
56+
assert_type(zscore(bool_1d, axis=None), onp.Array1D[np.float64])
57+
assert_type(zscore(i16_1d, axis=None), onp.Array1D[np.float64])
58+
assert_type(zscore(f32_1d, axis=None), onp.Array1D[np.float32])
59+
assert_type(zscore(f64_1d, axis=None), onp.Array1D[np.float64])
60+
assert_type(zscore(c64_1d, axis=None), onp.Array1D[np.complex64])
61+
assert_type(zscore(c128_1d, axis=None), onp.Array1D[np.complex128])
62+
63+
assert_type(zscore(py_i_2d, axis=None), onp.Array2D[np.float64])
64+
assert_type(zscore(py_f_2d, axis=None), onp.Array2D[np.float64])
65+
assert_type(zscore(bool_2d, axis=None), onp.Array2D[np.float64])
66+
assert_type(zscore(i16_2d, axis=None), onp.Array2D[np.float64])
67+
assert_type(zscore(f32_2d, axis=None), onp.Array2D[np.float32])
68+
assert_type(zscore(f64_2d, axis=None), onp.Array2D[np.float64])
69+
assert_type(zscore(c64_2d, axis=None), onp.Array2D[np.complex64])
70+
assert_type(zscore(c128_2d, axis=None), onp.Array2D[np.complex128])

0 commit comments

Comments
 (0)