@@ -6,6 +6,7 @@ from typing_extensions import NamedTuple, TypeVar, deprecated
66
77import numpy as np
88import numpy .typing as npt
9+ import numpy_typing_compat as nptc
910import optype as op
1011import optype .numpy as onp
1112import optype .numpy .compat as npc
@@ -89,6 +90,8 @@ __all__ = [
8990
9091_SCT = TypeVar ("_SCT" , bound = np .generic )
9192
93+ _ShapeT = TypeVar ("_ShapeT" , bound = tuple [int , ...])
94+ _InexactT = TypeVar ("_InexactT" , bound = npc .inexact )
9295_FloatT = TypeVar ("_FloatT" , bound = npc .floating , default = npc .floating )
9396_RealT = TypeVar ("_RealT" , bound = _Real0D , default = _Real0D )
9497_RealT_co = TypeVar ("_RealT_co" , bound = _Real0D , default = _Real0D , covariant = True )
@@ -684,17 +687,98 @@ def sem(
684687 a : onp .ToComplexND , axis : int | None = 0 , ddof : int = 1 , nan_policy : NanPolicy = "propagate" , * , keepdims : bool = False
685688) -> _FloatOrND : ...
686689
687- # TODO(jorenham): improve
690+ # NOTE: keep in sync with `gzscore`
691+ @overload # +integer, known shape
692+ def zscore (
693+ a : nptc .CanArray [_ShapeT , np .dtype [npc .integer | np .bool_ ]],
694+ axis : int | None = 0 ,
695+ ddof : int = 0 ,
696+ nan_policy : NanPolicy = "propagate" ,
697+ ) -> onp .ArrayND [np .float64 , _ShapeT ]: ...
698+ @overload # known inexact dtype, known shape
699+ def zscore (
700+ a : nptc .CanArray [_ShapeT , np .dtype [_InexactT ]], axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
701+ ) -> onp .ArrayND [_InexactT , _ShapeT ]: ...
702+ @overload # float 1d
703+ def zscore (
704+ a : Sequence [float ], axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
705+ ) -> onp .Array1D [np .float64 ]: ...
706+ @overload # float 2d
707+ def zscore (
708+ a : Sequence [Sequence [float ]], axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
709+ ) -> onp .Array2D [np .float64 ]: ...
710+ @overload # float 3d
711+ def zscore (
712+ a : Sequence [Sequence [Sequence [float ]]], axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
713+ ) -> onp .Array3D [np .float64 ]: ...
714+ @overload # complex 1d
715+ def zscore (
716+ a : Sequence [op .JustComplex ], axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
717+ ) -> onp .Array1D [np .complex128 ]: ...
718+ @overload # complex 2d
688719def zscore (
720+ a : Sequence [Sequence [op .JustComplex ]], axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
721+ ) -> onp .Array2D [np .complex128 ]: ...
722+ @overload # complex 3d
723+ def zscore (
724+ a : Sequence [Sequence [Sequence [op .JustComplex ]]], axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
725+ ) -> onp .Array3D [np .complex128 ]: ...
726+ @overload # floating fallback
727+ def zscore ( # the weird shape-type is a workaround for a bug in pyright's overlapping overload detection on numpy<2.1
689728 a : onp .ToFloatND , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
690- ) -> onp .ArrayND [npc .floating ]: ...
729+ ) -> onp .ArrayND [npc .floating , tuple [int ] | tuple [Any , ...]]: ...
730+ @overload # complex fallback
731+ def zscore (
732+ a : onp .ToJustComplexND , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
733+ ) -> onp .ArrayND [npc .complexfloating ]: ...
691734
692- # TODO(jorenham): improve
735+ # NOTE: keep in sync with `zscore`
736+ @overload # +integer, known shape
737+ def gzscore (
738+ a : nptc .CanArray [_ShapeT , np .dtype [npc .integer | np .bool_ ]],
739+ * ,
740+ axis : int | None = 0 ,
741+ ddof : int = 0 ,
742+ nan_policy : NanPolicy = "propagate" ,
743+ ) -> onp .ArrayND [np .float64 , _ShapeT ]: ...
744+ @overload # known inexact dtype, known shape
745+ def gzscore (
746+ a : nptc .CanArray [_ShapeT , np .dtype [_InexactT ]], * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
747+ ) -> onp .ArrayND [_InexactT , _ShapeT ]: ...
748+ @overload # float 1d
749+ def gzscore (
750+ a : Sequence [float ], * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
751+ ) -> onp .Array1D [np .float64 ]: ...
752+ @overload # float 2d
753+ def gzscore (
754+ a : Sequence [Sequence [float ]], * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
755+ ) -> onp .Array2D [np .float64 ]: ...
756+ @overload # float 3d
757+ def gzscore (
758+ a : Sequence [Sequence [Sequence [float ]]], * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
759+ ) -> onp .Array3D [np .float64 ]: ...
760+ @overload # complex 1d
693761def gzscore (
762+ a : Sequence [op .JustComplex ], * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
763+ ) -> onp .Array1D [np .complex128 ]: ...
764+ @overload # complex 2d
765+ def gzscore (
766+ a : Sequence [Sequence [op .JustComplex ]], * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
767+ ) -> onp .Array2D [np .complex128 ]: ...
768+ @overload # complex 3d
769+ def gzscore (
770+ a : Sequence [Sequence [Sequence [op .JustComplex ]]], * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
771+ ) -> onp .Array3D [np .complex128 ]: ...
772+ @overload # floating fallback
773+ def gzscore ( # the weird shape-type is a workaround for a bug in pyright's overlapping overload detection on numpy<2.1
694774 a : onp .ToFloatND , * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
695- ) -> onp .ArrayND [npc .floating ]: ...
775+ ) -> onp .ArrayND [npc .floating , tuple [int ] | tuple [Any , ...]]: ...
776+ @overload # complex fallback
777+ def gzscore (
778+ a : onp .ToJustComplexND , * , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
779+ ) -> onp .ArrayND [npc .complexfloating ]: ...
696780
697- # TODO(jorenham): improve
781+ # TODO(jorenham): improve like zscore
698782@overload # (real vector-like, real vector-like) -> floating vector
699783def zmap (
700784 scores : onp .ToFloat1D , compare : onp .ToFloat1D , axis : int | None = 0 , ddof : int = 0 , nan_policy : NanPolicy = "propagate"
@@ -1568,7 +1652,7 @@ def lmoment(
15681652 sample : onp .ToFloatStrict2D ,
15691653 order : _LMomentOrder ,
15701654 * ,
1571- axis : L [ 0 , 1 , - 1 , - 2 ] = 0 ,
1655+ axis : int = 0 ,
15721656 keepdims : onp .ToFalse = False ,
15731657 sorted : op .CanBool = False ,
15741658 standardize : op .CanBool = True ,
@@ -1579,7 +1663,7 @@ def lmoment(
15791663 sample : onp .ToFloatStrict2D ,
15801664 order : _LMomentOrder ,
15811665 * ,
1582- axis : L [ 0 , 1 , - 1 , - 2 ] | None = 0 ,
1666+ axis : int | None = 0 ,
15831667 keepdims : onp .ToTrue ,
15841668 sorted : op .CanBool = False ,
15851669 standardize : op .CanBool = True ,
@@ -1590,7 +1674,7 @@ def lmoment(
15901674 sample : onp .ToFloatStrict2D ,
15911675 order : _LMomentOrder1D | None = None ,
15921676 * ,
1593- axis : L [ 0 , 1 , - 1 , - 2 ] = 0 ,
1677+ axis : int = 0 ,
15941678 keepdims : onp .ToFalse = False ,
15951679 sorted : op .CanBool = False ,
15961680 standardize : op .CanBool = True ,
@@ -1601,7 +1685,7 @@ def lmoment(
16011685 sample : onp .ToFloatStrict2D ,
16021686 order : _LMomentOrder1D | None = None ,
16031687 * ,
1604- axis : L [ 0 , 1 , - 1 , - 2 ] | None = 0 ,
1688+ axis : int | None = 0 ,
16051689 keepdims : onp .ToTrue ,
16061690 sorted : op .CanBool = False ,
16071691 standardize : op .CanBool = True ,
@@ -1612,7 +1696,7 @@ def lmoment(
16121696 sample : onp .ToFloatStrict3D ,
16131697 order : _LMomentOrder ,
16141698 * ,
1615- axis : L [ 0 , 1 , 2 , - 1 , - 2 , - 3 ] = 0 ,
1699+ axis : int = 0 ,
16161700 keepdims : onp .ToFalse = False ,
16171701 sorted : op .CanBool = False ,
16181702 standardize : op .CanBool = True ,
@@ -1623,7 +1707,7 @@ def lmoment(
16231707 sample : onp .ToFloatStrict3D ,
16241708 order : _LMomentOrder ,
16251709 * ,
1626- axis : L [ 0 , 1 , 2 , - 1 , - 2 , - 3 ] | None = 0 ,
1710+ axis : int | None = 0 ,
16271711 keepdims : onp .ToTrue ,
16281712 sorted : op .CanBool = False ,
16291713 standardize : op .CanBool = True ,
@@ -1634,7 +1718,7 @@ def lmoment(
16341718 sample : onp .ToFloatStrict3D ,
16351719 order : _LMomentOrder1D | None = None ,
16361720 * ,
1637- axis : L [ 0 , 1 , 2 , - 1 , - 2 , - 3 ] = 0 ,
1721+ axis : int = 0 ,
16381722 keepdims : onp .ToFalse = False ,
16391723 sorted : op .CanBool = False ,
16401724 standardize : op .CanBool = True ,
@@ -1645,7 +1729,7 @@ def lmoment(
16451729 sample : onp .ToFloatStrict3D ,
16461730 order : _LMomentOrder1D | None = None ,
16471731 * ,
1648- axis : L [ 0 , 1 , 2 , - 1 , - 2 , - 3 ] | None = 0 ,
1732+ axis : int | None = 0 ,
16491733 keepdims : onp .ToTrue ,
16501734 sorted : op .CanBool = False ,
16511735 standardize : op .CanBool = True ,
0 commit comments