Skip to content

Commit 3ac1b7a

Browse files
committed
stats: improve zscore and gzscore
1 parent 593f6c3 commit 3ac1b7a

File tree

1 file changed

+94
-11
lines changed

1 file changed

+94
-11
lines changed

scipy-stubs/stats/_stats_py.pyi

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ __all__ = [
8989

9090
_SCT = TypeVar("_SCT", bound=np.generic)
9191

92+
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
93+
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
9294
_FloatT = TypeVar("_FloatT", bound=npc.floating, default=npc.floating)
9395
_RealT = TypeVar("_RealT", bound=_Real0D, default=_Real0D)
9496
_RealT_co = TypeVar("_RealT_co", bound=_Real0D, default=_Real0D, covariant=True)
@@ -684,17 +686,98 @@ def sem(
684686
a: onp.ToComplexND, axis: int | None = 0, ddof: int = 1, nan_policy: NanPolicy = "propagate", *, keepdims: bool = False
685687
) -> _FloatOrND: ...
686688

687-
# TODO(jorenham): improve
689+
# NOTE: keep in sync with `gzscore`
690+
@overload # +integer, known shape
691+
def zscore(
692+
a: onp.CanArray[_ShapeT, np.dtype[npc.integer | np.bool_]],
693+
axis: int | None = 0,
694+
ddof: int = 0,
695+
nan_policy: NanPolicy = "propagate",
696+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
697+
@overload # known inexact dtype, known shape
698+
def zscore(
699+
a: onp.CanArray[_ShapeT, np.dtype[_InexactT]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
700+
) -> onp.ArrayND[_InexactT, _ShapeT]: ...
701+
@overload # float 1d
702+
def zscore(
703+
a: Sequence[float], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
704+
) -> onp.Array1D[np.float64]: ...
705+
@overload # float 2d
706+
def zscore(
707+
a: Sequence[Sequence[float]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
708+
) -> onp.Array2D[np.float64]: ...
709+
@overload # float 3d
710+
def zscore(
711+
a: Sequence[Sequence[Sequence[float]]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
712+
) -> onp.Array3D[np.float64]: ...
713+
@overload # complex 1d
714+
def zscore(
715+
a: Sequence[op.JustComplex], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
716+
) -> onp.Array1D[np.complex128]: ...
717+
@overload # complex 2d
718+
def zscore(
719+
a: Sequence[Sequence[op.JustComplex]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
720+
) -> onp.Array2D[np.complex128]: ...
721+
@overload # complex 3d
722+
def zscore(
723+
a: Sequence[Sequence[Sequence[op.JustComplex]]], axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
724+
) -> onp.Array3D[np.complex128]: ...
725+
@overload # floating fallback
688726
def zscore(
689727
a: onp.ToFloatND, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
690728
) -> onp.ArrayND[npc.floating]: ...
729+
@overload # complex fallback
730+
def zscore(
731+
a: onp.ToJustComplexND, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
732+
) -> onp.ArrayND[npc.complexfloating]: ...
691733

692-
# TODO(jorenham): improve
734+
# NOTE: keep in sync with `zscore`
735+
@overload # +integer, known shape
736+
def gzscore(
737+
a: onp.CanArray[_ShapeT, np.dtype[npc.integer | np.bool_]],
738+
*,
739+
axis: int | None = 0,
740+
ddof: int = 0,
741+
nan_policy: NanPolicy = "propagate",
742+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
743+
@overload # known inexact dtype, known shape
744+
def gzscore(
745+
a: onp.CanArray[_ShapeT, np.dtype[_InexactT]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
746+
) -> onp.ArrayND[_InexactT, _ShapeT]: ...
747+
@overload # float 1d
748+
def gzscore(
749+
a: Sequence[float], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
750+
) -> onp.Array1D[np.float64]: ...
751+
@overload # float 2d
752+
def gzscore(
753+
a: Sequence[Sequence[float]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
754+
) -> onp.Array2D[np.float64]: ...
755+
@overload # float 3d
756+
def gzscore(
757+
a: Sequence[Sequence[Sequence[float]]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
758+
) -> onp.Array3D[np.float64]: ...
759+
@overload # complex 1d
760+
def gzscore(
761+
a: Sequence[op.JustComplex], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
762+
) -> onp.Array1D[np.complex128]: ...
763+
@overload # complex 2d
764+
def gzscore(
765+
a: Sequence[Sequence[op.JustComplex]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
766+
) -> onp.Array2D[np.complex128]: ...
767+
@overload # complex 3d
768+
def gzscore(
769+
a: Sequence[Sequence[Sequence[op.JustComplex]]], *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
770+
) -> onp.Array3D[np.complex128]: ...
771+
@overload # floating fallback
693772
def gzscore(
694773
a: onp.ToFloatND, *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
695774
) -> onp.ArrayND[npc.floating]: ...
775+
@overload # complex fallback
776+
def gzscore(
777+
a: onp.ToJustComplexND, *, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
778+
) -> onp.ArrayND[npc.complexfloating]: ...
696779

697-
# TODO(jorenham): improve
780+
# TODO(jorenham): improve like zscore
698781
@overload # (real vector-like, real vector-like) -> floating vector
699782
def zmap(
700783
scores: onp.ToFloat1D, compare: onp.ToFloat1D, axis: int | None = 0, ddof: int = 0, nan_policy: NanPolicy = "propagate"
@@ -1568,7 +1651,7 @@ def lmoment(
15681651
sample: onp.ToFloatStrict2D,
15691652
order: _LMomentOrder,
15701653
*,
1571-
axis: L[0, 1, -1, -2] = 0,
1654+
axis: int = 0,
15721655
keepdims: onp.ToFalse = False,
15731656
sorted: op.CanBool = False,
15741657
standardize: op.CanBool = True,
@@ -1579,7 +1662,7 @@ def lmoment(
15791662
sample: onp.ToFloatStrict2D,
15801663
order: _LMomentOrder,
15811664
*,
1582-
axis: L[0, 1, -1, -2] | None = 0,
1665+
axis: int | None = 0,
15831666
keepdims: onp.ToTrue,
15841667
sorted: op.CanBool = False,
15851668
standardize: op.CanBool = True,
@@ -1590,7 +1673,7 @@ def lmoment(
15901673
sample: onp.ToFloatStrict2D,
15911674
order: _LMomentOrder1D | None = None,
15921675
*,
1593-
axis: L[0, 1, -1, -2] = 0,
1676+
axis: int = 0,
15941677
keepdims: onp.ToFalse = False,
15951678
sorted: op.CanBool = False,
15961679
standardize: op.CanBool = True,
@@ -1601,7 +1684,7 @@ def lmoment(
16011684
sample: onp.ToFloatStrict2D,
16021685
order: _LMomentOrder1D | None = None,
16031686
*,
1604-
axis: L[0, 1, -1, -2] | None = 0,
1687+
axis: int | None = 0,
16051688
keepdims: onp.ToTrue,
16061689
sorted: op.CanBool = False,
16071690
standardize: op.CanBool = True,
@@ -1612,7 +1695,7 @@ def lmoment(
16121695
sample: onp.ToFloatStrict3D,
16131696
order: _LMomentOrder,
16141697
*,
1615-
axis: L[0, 1, 2, -1, -2, -3] = 0,
1698+
axis: int = 0,
16161699
keepdims: onp.ToFalse = False,
16171700
sorted: op.CanBool = False,
16181701
standardize: op.CanBool = True,
@@ -1623,7 +1706,7 @@ def lmoment(
16231706
sample: onp.ToFloatStrict3D,
16241707
order: _LMomentOrder,
16251708
*,
1626-
axis: L[0, 1, 2, -1, -2, -3] | None = 0,
1709+
axis: int | None = 0,
16271710
keepdims: onp.ToTrue,
16281711
sorted: op.CanBool = False,
16291712
standardize: op.CanBool = True,
@@ -1634,7 +1717,7 @@ def lmoment(
16341717
sample: onp.ToFloatStrict3D,
16351718
order: _LMomentOrder1D | None = None,
16361719
*,
1637-
axis: L[0, 1, 2, -1, -2, -3] = 0,
1720+
axis: int = 0,
16381721
keepdims: onp.ToFalse = False,
16391722
sorted: op.CanBool = False,
16401723
standardize: op.CanBool = True,
@@ -1645,7 +1728,7 @@ def lmoment(
16451728
sample: onp.ToFloatStrict3D,
16461729
order: _LMomentOrder1D | None = None,
16471730
*,
1648-
axis: L[0, 1, 2, -1, -2, -3] | None = 0,
1731+
axis: int | None = 0,
16491732
keepdims: onp.ToTrue,
16501733
sorted: op.CanBool = False,
16511734
standardize: op.CanBool = True,

0 commit comments

Comments
 (0)