Skip to content

Commit e396b6f

Browse files
authored
stats: improved Covariance generic type parameter inference (#739)
2 parents 94d3f4c + 67162f9 commit e396b6f

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

scipy-stubs/stats/_covariance.pyi

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections.abc import Sequence
2-
from typing import Final, Generic, Protocol, TypeAlias, overload, type_check_only
1+
from typing import Final, Generic, Protocol, overload, type_check_only
32
from typing_extensions import TypeVar
43

54
import numpy as np
@@ -8,46 +7,49 @@ import optype.numpy.compat as npc
87

98
__all__ = ["Covariance"]
109

11-
# `float16` and `longdouble` aren't supported in `scipy.linalg`, and neither is `bool_`
12-
_Scalar_uif: TypeAlias = np.float32 | np.float64 | npc.integer
10+
_ScalarT = TypeVar("_ScalarT", bound=npc.floating | npc.integer)
11+
_ScalarT_co = TypeVar("_ScalarT_co", bound=npc.floating | npc.integer, default=np.float64, covariant=True)
1312

14-
_SCT = TypeVar("_SCT", bound=_Scalar_uif)
15-
_SCT_co = TypeVar("_SCT_co", bound=_Scalar_uif, covariant=True, default=np.float64)
16-
17-
class Covariance(Generic[_SCT_co]):
13+
class Covariance(Generic[_ScalarT_co]):
1814
@staticmethod
1915
@overload
20-
def from_diagonal(diagonal: Sequence[int]) -> CovViaDiagonal[np.int_]: ...
16+
def from_diagonal(diagonal: onp.ToJustFloat64_1D) -> CovViaDiagonal[np.float64]: ...
2117
@staticmethod
2218
@overload
23-
def from_diagonal(diagonal: Sequence[float]) -> CovViaDiagonal[np.int_ | np.float64]: ...
19+
def from_diagonal(diagonal: onp.ToJustInt64_1D) -> CovViaDiagonal[np.int_]: ...
2420
@staticmethod
2521
@overload
26-
def from_diagonal(diagonal: Sequence[_SCT] | onp.CanArrayND[_SCT]) -> CovViaDiagonal[_SCT]: ...
22+
def from_diagonal(diagonal: onp.ToArray1D[_ScalarT, _ScalarT]) -> CovViaDiagonal[_ScalarT]: ...
23+
24+
#
2725
@staticmethod
2826
def from_precision(precision: onp.ToFloat2D, covariance: onp.ToFloat2D | None = None) -> CovViaPrecision: ...
2927
@staticmethod
3028
def from_cholesky(cholesky: onp.ToFloat2D) -> CovViaCholesky: ...
3129
@staticmethod
3230
def from_eigendecomposition(eigendecomposition: tuple[onp.ToFloat1D, onp.ToFloat2D]) -> CovViaEigendecomposition: ...
33-
def whiten(self, /, x: onp.AnyIntegerArray | onp.AnyFloatingArray) -> onp.ArrayND[npc.floating]: ...
34-
def colorize(self, /, x: onp.AnyIntegerArray | onp.AnyFloatingArray) -> onp.ArrayND[npc.floating]: ...
31+
32+
#
3533
@property
3634
def log_pdet(self, /) -> np.float64: ...
3735
@property
3836
def rank(self, /) -> np.int_: ...
3937
@property
40-
def covariance(self, /) -> onp.Array2D[_SCT_co]: ...
38+
def covariance(self, /) -> onp.Array2D[_ScalarT_co]: ...
4139
@property
4240
def shape(self, /) -> tuple[int, int]: ...
4341

44-
class CovViaDiagonal(Covariance[_SCT_co], Generic[_SCT_co]):
42+
#
43+
def whiten(self, /, x: onp.ToFloatND) -> onp.ArrayND[npc.floating]: ...
44+
def colorize(self, /, x: onp.ToFloatND) -> onp.ArrayND[npc.floating]: ...
45+
46+
class CovViaDiagonal(Covariance[_ScalarT_co], Generic[_ScalarT_co]):
4547
@overload
46-
def __init__(self: CovViaDiagonal[np.int_], /, diagonal: Sequence[int]) -> None: ...
48+
def __init__(self: CovViaDiagonal[np.float64], /, diagonal: onp.ToJustFloat64_1D) -> None: ...
4749
@overload
48-
def __init__(self: CovViaDiagonal[np.int_ | np.float64], /, diagonal: Sequence[float]) -> None: ...
50+
def __init__(self: CovViaDiagonal[np.int_], /, diagonal: onp.ToJustInt64_1D) -> None: ...
4951
@overload
50-
def __init__(self, /, diagonal: Sequence[float | _SCT_co] | onp.CanArrayND[_SCT_co]) -> None: ...
52+
def __init__(self, /, diagonal: onp.ToArray1D[_ScalarT_co, _ScalarT_co]) -> None: ...
5153

5254
class CovViaPrecision(Covariance[np.float64]):
5355
def __init__(self, /, precision: onp.ToFloat2D, covariance: onp.ToFloat2D | None = None) -> None: ...
@@ -63,17 +65,17 @@ class _PSD(Protocol):
6365
_M: onp.ArrayND[np.float64]
6466
V: onp.ArrayND[np.float64]
6567
U: onp.ArrayND[np.float64]
66-
eps: np.float64 | float
67-
log_pdet: np.float64 | float
68-
cond: np.float64 | float
68+
eps: float
69+
log_pdet: float
70+
cond: float
6971
rank: int
7072

7173
@property
7274
def pinv(self, /) -> onp.ArrayND[npc.floating]: ...
7375

7476
class CovViaPSD(Covariance[np.float64]):
7577
_LP: Final[onp.ArrayND[np.float64]]
76-
_log_pdet: Final[np.float64 | float]
78+
_log_pdet: Final[float]
7779
_rank: Final[int]
7880
_covariance: Final[onp.ArrayND[np.float64]]
7981
_shape: tuple[int, int]

0 commit comments

Comments
 (0)