Skip to content

Commit 4b16e52

Browse files
authored
🐛 stats: fix rv_discrete sample constructor (#418)
2 parents 5a48e4e + 136e09b commit 4b16e52

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

scipy-stubs/stats/_distn_infrastructure.pyi

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from typing_extensions import Self, TypeVar, Unpack, override
1111
import numpy as np
1212
import optype as op
1313
import optype.numpy as onp
14+
import optype.numpy.compat as npc
1415
from scipy._typing import RNG, AnyShape, Falsy, ToRNG, Truthy
1516
from scipy.integrate._typing import QuadOpts as _QuadOpts
1617

@@ -24,16 +25,15 @@ _RVT = TypeVar("_RVT", bound=rv_generic, default=rv_generic)
2425
_RVT_co = TypeVar("_RVT_co", bound=rv_generic, default=rv_generic, covariant=True)
2526
_CRVT_co = TypeVar("_CRVT_co", bound=rv_continuous, default=rv_continuous, covariant=True)
2627
_DRVT_co = TypeVar("_DRVT_co", bound=rv_discrete, default=rv_discrete, covariant=True)
27-
_XKT_co = TypeVar("_XKT_co", bound=np.number[Any], covariant=True, default=np.number[Any])
28+
_XKT_co = TypeVar("_XKT_co", bound=_CoFloat, covariant=True, default=_CoFloat)
2829
_PKT_co = TypeVar("_PKT_co", bound=_Floating, covariant=True, default=_Floating)
2930

3031
_Tuple2: TypeAlias = tuple[_T, _T]
3132
_Tuple3: TypeAlias = tuple[_T, _T, _T]
3233
_Tuple4: TypeAlias = tuple[_T, _T, _T, _T]
3334

34-
_Integer: TypeAlias = np.integer[Any]
3535
_Floating: TypeAlias = np.float64 | np.float32 | np.float16 # longdouble often results in trouble
36-
_CoFloat: TypeAlias = _Floating | _Integer
36+
_CoFloat: TypeAlias = _Floating | npc.integer
3737

3838
_Bool: TypeAlias = bool | np.bool_
3939
_Int: TypeAlias = int | np.int32 | np.int64
@@ -841,42 +841,76 @@ class rv_discrete(_rv_mixin, rv_generic):
841841
inc: Final[int]
842842
moment_tol: Final[float]
843843

844+
@overload
844845
def __new__(
845846
cls,
846847
a: onp.ToFloat = 0,
847848
b: onp.ToFloat = ...,
848849
name: str | None = None,
849850
badvalue: _Float | None = None,
850-
moment_tol: _Float = 1e-08,
851-
values: _Tuple2[_ToFloatOrND] | None = None,
851+
moment_tol: _Float = 1e-8,
852+
values: None = None,
852853
inc: int | np.int_ = 1,
853854
longname: str | None = None,
854855
shapes: str | None = None,
855856
seed: ToRNG = None,
856857
) -> Self: ...
857-
def __init__( # pyright: ignore[reportInconsistentConstructor]
858+
# NOTE: The return types of the following overloads is ignored by mypy
859+
@overload
860+
def __new__(
861+
cls,
862+
a: onp.ToFloat,
863+
b: onp.ToFloat,
864+
name: str | None,
865+
badvalue: _Float | None,
866+
moment_tol: _Float,
867+
values: _Tuple2[onp.ToFloatND],
868+
inc: int | np.int_ = 1,
869+
longname: str | None = None,
870+
shapes: str | None = None,
871+
seed: ToRNG = None,
872+
) -> rv_sample: ...
873+
@overload
874+
def __new__(
875+
cls,
876+
a: onp.ToFloat = 0,
877+
b: onp.ToFloat = ...,
878+
name: str | None = None,
879+
badvalue: _Float | None = None,
880+
moment_tol: _Float = 1e-8,
881+
*,
882+
values: _Tuple2[onp.ToFloatND],
883+
inc: int | np.int_ = 1,
884+
longname: str | None = None,
885+
shapes: str | None = None,
886+
seed: ToRNG = None,
887+
) -> rv_sample: ...
888+
889+
#
890+
def __init__(
858891
self,
859892
/,
860893
a: onp.ToFloat = 0,
861894
b: onp.ToFloat = ...,
862895
name: str | None = None,
863896
badvalue: _Float | None = None,
864-
moment_tol: _Float = 1e-08,
865-
values: None = None,
897+
moment_tol: _Float = 1e-8,
898+
# mypy workaround: `values` can only be None
899+
values: _Tuple2[onp.ToFloatND] | None = None,
866900
inc: int | np.int_ = 1,
867901
longname: str | None = None,
868902
shapes: str | None = None,
869903
seed: ToRNG = None,
870904
) -> None: ...
871905

872-
#
873906
# NOTE: Using `@override` on `__call__` or `freeze` causes stubtest to crash (mypy 1.11.1)
874907
@overload
875908
def __call__(self, /) -> rv_discrete_frozen[Self, _Float]: ...
876909
@overload
877910
def __call__(self, /, *args: onp.ToFloat, loc: onp.ToFloat = 0, **kwds: onp.ToFloat) -> rv_discrete_frozen[Self, _Float]: ...
878911
@overload
879912
def __call__(self, /, *args: _ToFloatOrND, loc: _ToFloatOrND = 0, **kwds: _ToFloatOrND) -> rv_discrete_frozen[Self]: ...
913+
880914
#
881915
@overload
882916
def freeze(self, /) -> rv_discrete_frozen[Self, _Float]: ...
@@ -1028,22 +1062,23 @@ class rv_discrete(_rv_mixin, rv_generic):
10281062
**kwds: _ToFloatOrND,
10291063
) -> _IntOrND: ...
10301064

1031-
# undocumented
1065+
# returned by `rv_discrete.__new__` if `values` is specified
10321066
class rv_sample(rv_discrete, Generic[_XKT_co, _PKT_co]):
10331067
xk: onp.Array1D[_XKT_co]
10341068
pk: onp.Array1D[_PKT_co]
10351069
qvals: onp.Array1D[_PKT_co]
10361070

1037-
def __init__( # pyright: ignore[reportInconsistentConstructor]
1071+
def __init__(
10381072
self,
10391073
/,
10401074
a: onp.ToFloat = 0,
10411075
b: onp.ToFloat = ...,
10421076
name: str | None = None,
1043-
badvalue: float | None = None,
1044-
moment_tol: float = 1e-08,
1045-
values: tuple[_ToFloatOrND, _ToFloatOrND] | None = None,
1046-
inc: int = 1,
1077+
badvalue: _Float | None = None,
1078+
moment_tol: _Float = 1e-8,
1079+
# never None in practice, but required by stubtest
1080+
values: _Tuple2[onp.ToFloatND] | None = None,
1081+
inc: int | np.int_ = 1,
10471082
longname: str | None = None,
10481083
shapes: str | None = None,
10491084
seed: ToRNG = None,

tests/stats/test_rv_sample.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing_extensions import assert_type
2+
3+
import numpy as np
4+
import optype.numpy as onp
5+
from scipy.stats._distn_infrastructure import rv_discrete, rv_sample
6+
7+
xk: onp.Array1D[np.int_]
8+
pk: tuple[float, ...]
9+
10+
# mypy fails because it (still) doesn't support __new__ returning something that isn't `Self`
11+
assert_type(rv_discrete(values=(xk, pk)), rv_sample) # type: ignore[assert-type]

0 commit comments

Comments
 (0)