@@ -11,6 +11,7 @@ from typing_extensions import Self, TypeVar, Unpack, override
11
11
import numpy as np
12
12
import optype as op
13
13
import optype .numpy as onp
14
+ import optype .numpy .compat as npc
14
15
from scipy ._typing import RNG , AnyShape , Falsy , ToRNG , Truthy
15
16
from scipy .integrate ._typing import QuadOpts as _QuadOpts
16
17
@@ -24,16 +25,15 @@ _RVT = TypeVar("_RVT", bound=rv_generic, default=rv_generic)
24
25
_RVT_co = TypeVar ("_RVT_co" , bound = rv_generic , default = rv_generic , covariant = True )
25
26
_CRVT_co = TypeVar ("_CRVT_co" , bound = rv_continuous , default = rv_continuous , covariant = True )
26
27
_DRVT_co = TypeVar ("_DRVT_co" , bound = rv_discrete , default = rv_discrete , covariant = True )
27
- _XKT_co = TypeVar ("_XKT_co" , bound = np . number [ Any ] , covariant = True , default = np . number [ Any ] )
28
+ _XKT_co = TypeVar ("_XKT_co" , bound = _CoFloat , covariant = True , default = _CoFloat )
28
29
_PKT_co = TypeVar ("_PKT_co" , bound = _Floating , covariant = True , default = _Floating )
29
30
30
31
_Tuple2 : TypeAlias = tuple [_T , _T ]
31
32
_Tuple3 : TypeAlias = tuple [_T , _T , _T ]
32
33
_Tuple4 : TypeAlias = tuple [_T , _T , _T , _T ]
33
34
34
- _Integer : TypeAlias = np .integer [Any ]
35
35
_Floating : TypeAlias = np .float64 | np .float32 | np .float16 # longdouble often results in trouble
36
- _CoFloat : TypeAlias = _Floating | _Integer
36
+ _CoFloat : TypeAlias = _Floating | npc . integer
37
37
38
38
_Bool : TypeAlias = bool | np .bool_
39
39
_Int : TypeAlias = int | np .int32 | np .int64
@@ -841,42 +841,76 @@ class rv_discrete(_rv_mixin, rv_generic):
841
841
inc : Final [int ]
842
842
moment_tol : Final [float ]
843
843
844
+ @overload
844
845
def __new__ (
845
846
cls ,
846
847
a : onp .ToFloat = 0 ,
847
848
b : onp .ToFloat = ...,
848
849
name : str | None = None ,
849
850
badvalue : _Float | None = None ,
850
- moment_tol : _Float = 1e-08 ,
851
- values : _Tuple2 [ _ToFloatOrND ] | None = None ,
851
+ moment_tol : _Float = 1e-8 ,
852
+ values : None = None ,
852
853
inc : int | np .int_ = 1 ,
853
854
longname : str | None = None ,
854
855
shapes : str | None = None ,
855
856
seed : ToRNG = None ,
856
857
) -> Self : ...
857
- def __init__ ( # pyright: ignore[reportInconsistentConstructor]
858
+ # NOTE: The return types of the following overloads is ignored by mypy
859
+ @overload
860
+ def __new__ (
861
+ cls ,
862
+ a : onp .ToFloat ,
863
+ b : onp .ToFloat ,
864
+ name : str | None ,
865
+ badvalue : _Float | None ,
866
+ moment_tol : _Float ,
867
+ values : _Tuple2 [onp .ToFloatND ],
868
+ inc : int | np .int_ = 1 ,
869
+ longname : str | None = None ,
870
+ shapes : str | None = None ,
871
+ seed : ToRNG = None ,
872
+ ) -> rv_sample : ...
873
+ @overload
874
+ def __new__ (
875
+ cls ,
876
+ a : onp .ToFloat = 0 ,
877
+ b : onp .ToFloat = ...,
878
+ name : str | None = None ,
879
+ badvalue : _Float | None = None ,
880
+ moment_tol : _Float = 1e-8 ,
881
+ * ,
882
+ values : _Tuple2 [onp .ToFloatND ],
883
+ inc : int | np .int_ = 1 ,
884
+ longname : str | None = None ,
885
+ shapes : str | None = None ,
886
+ seed : ToRNG = None ,
887
+ ) -> rv_sample : ...
888
+
889
+ #
890
+ def __init__ (
858
891
self ,
859
892
/ ,
860
893
a : onp .ToFloat = 0 ,
861
894
b : onp .ToFloat = ...,
862
895
name : str | None = None ,
863
896
badvalue : _Float | None = None ,
864
- moment_tol : _Float = 1e-08 ,
865
- values : None = None ,
897
+ moment_tol : _Float = 1e-8 ,
898
+ # mypy workaround: `values` can only be None
899
+ values : _Tuple2 [onp .ToFloatND ] | None = None ,
866
900
inc : int | np .int_ = 1 ,
867
901
longname : str | None = None ,
868
902
shapes : str | None = None ,
869
903
seed : ToRNG = None ,
870
904
) -> None : ...
871
905
872
- #
873
906
# NOTE: Using `@override` on `__call__` or `freeze` causes stubtest to crash (mypy 1.11.1)
874
907
@overload
875
908
def __call__ (self , / ) -> rv_discrete_frozen [Self , _Float ]: ...
876
909
@overload
877
910
def __call__ (self , / , * args : onp .ToFloat , loc : onp .ToFloat = 0 , ** kwds : onp .ToFloat ) -> rv_discrete_frozen [Self , _Float ]: ...
878
911
@overload
879
912
def __call__ (self , / , * args : _ToFloatOrND , loc : _ToFloatOrND = 0 , ** kwds : _ToFloatOrND ) -> rv_discrete_frozen [Self ]: ...
913
+
880
914
#
881
915
@overload
882
916
def freeze (self , / ) -> rv_discrete_frozen [Self , _Float ]: ...
@@ -1028,22 +1062,23 @@ class rv_discrete(_rv_mixin, rv_generic):
1028
1062
** kwds : _ToFloatOrND ,
1029
1063
) -> _IntOrND : ...
1030
1064
1031
- # undocumented
1065
+ # returned by `rv_discrete.__new__` if `values` is specified
1032
1066
class rv_sample (rv_discrete , Generic [_XKT_co , _PKT_co ]):
1033
1067
xk : onp .Array1D [_XKT_co ]
1034
1068
pk : onp .Array1D [_PKT_co ]
1035
1069
qvals : onp .Array1D [_PKT_co ]
1036
1070
1037
- def __init__ ( # pyright: ignore[reportInconsistentConstructor]
1071
+ def __init__ (
1038
1072
self ,
1039
1073
/ ,
1040
1074
a : onp .ToFloat = 0 ,
1041
1075
b : onp .ToFloat = ...,
1042
1076
name : str | None = None ,
1043
- badvalue : float | None = None ,
1044
- moment_tol : float = 1e-08 ,
1045
- values : tuple [_ToFloatOrND , _ToFloatOrND ] | None = None ,
1046
- inc : int = 1 ,
1077
+ badvalue : _Float | None = None ,
1078
+ moment_tol : _Float = 1e-8 ,
1079
+ # never None in practice, but required by stubtest
1080
+ values : _Tuple2 [onp .ToFloatND ] | None = None ,
1081
+ inc : int | np .int_ = 1 ,
1047
1082
longname : str | None = None ,
1048
1083
shapes : str | None = None ,
1049
1084
seed : ToRNG = None ,
0 commit comments