1
- from collections .abc import Sequence
2
- from typing import Final , Generic , Protocol , TypeAlias , overload , type_check_only
1
+ from typing import Final , Generic , Protocol , overload , type_check_only
3
2
from typing_extensions import TypeVar
4
3
5
4
import numpy as np
@@ -8,46 +7,49 @@ import optype.numpy.compat as npc
8
7
9
8
__all__ = ["Covariance" ]
10
9
11
- # `float16` and `longdouble` aren't supported in `scipy.linalg`, and neither is `bool_`
12
- _Scalar_uif : TypeAlias = np . float32 | np . float64 | npc .integer
10
+ _ScalarT = TypeVar ( "_ScalarT" , bound = npc . floating | npc . integer )
11
+ _ScalarT_co = TypeVar ( "_ScalarT_co" , bound = npc . floating | npc .integer , default = np . float64 , covariant = True )
13
12
14
- _SCT = TypeVar ("_SCT" , bound = _Scalar_uif )
15
- _SCT_co = TypeVar ("_SCT_co" , bound = _Scalar_uif , covariant = True , default = np .float64 )
16
-
17
- class Covariance (Generic [_SCT_co ]):
13
+ class Covariance (Generic [_ScalarT_co ]):
18
14
@staticmethod
19
15
@overload
20
- def from_diagonal (diagonal : Sequence [ int ] ) -> CovViaDiagonal [np .int_ ]: ...
16
+ def from_diagonal (diagonal : onp . ToJustFloat64_1D ) -> CovViaDiagonal [np .float64 ]: ...
21
17
@staticmethod
22
18
@overload
23
- def from_diagonal (diagonal : Sequence [ float ] ) -> CovViaDiagonal [np .int_ | np . float64 ]: ...
19
+ def from_diagonal (diagonal : onp . ToJustInt64_1D ) -> CovViaDiagonal [np .int_ ]: ...
24
20
@staticmethod
25
21
@overload
26
- def from_diagonal (diagonal : Sequence [_SCT ] | onp .CanArrayND [_SCT ]) -> CovViaDiagonal [_SCT ]: ...
22
+ def from_diagonal (diagonal : onp .ToArray1D [_ScalarT , _ScalarT ]) -> CovViaDiagonal [_ScalarT ]: ...
23
+
24
+ #
27
25
@staticmethod
28
26
def from_precision (precision : onp .ToFloat2D , covariance : onp .ToFloat2D | None = None ) -> CovViaPrecision : ...
29
27
@staticmethod
30
28
def from_cholesky (cholesky : onp .ToFloat2D ) -> CovViaCholesky : ...
31
29
@staticmethod
32
30
def from_eigendecomposition (eigendecomposition : tuple [onp .ToFloat1D , onp .ToFloat2D ]) -> CovViaEigendecomposition : ...
33
- def whiten ( self , / , x : onp . AnyIntegerArray | onp . AnyFloatingArray ) -> onp . ArrayND [ npc . floating ]: ...
34
- def colorize ( self , / , x : onp . AnyIntegerArray | onp . AnyFloatingArray ) -> onp . ArrayND [ npc . floating ]: ...
31
+
32
+ #
35
33
@property
36
34
def log_pdet (self , / ) -> np .float64 : ...
37
35
@property
38
36
def rank (self , / ) -> np .int_ : ...
39
37
@property
40
- def covariance (self , / ) -> onp .Array2D [_SCT_co ]: ...
38
+ def covariance (self , / ) -> onp .Array2D [_ScalarT_co ]: ...
41
39
@property
42
40
def shape (self , / ) -> tuple [int , int ]: ...
43
41
44
- class CovViaDiagonal (Covariance [_SCT_co ], Generic [_SCT_co ]):
42
+ #
43
+ def whiten (self , / , x : onp .ToFloatND ) -> onp .ArrayND [npc .floating ]: ...
44
+ def colorize (self , / , x : onp .ToFloatND ) -> onp .ArrayND [npc .floating ]: ...
45
+
46
+ class CovViaDiagonal (Covariance [_ScalarT_co ], Generic [_ScalarT_co ]):
45
47
@overload
46
- def __init__ (self : CovViaDiagonal [np .int_ ], / , diagonal : Sequence [ int ] ) -> None : ...
48
+ def __init__ (self : CovViaDiagonal [np .float64 ], / , diagonal : onp . ToJustFloat64_1D ) -> None : ...
47
49
@overload
48
- def __init__ (self : CovViaDiagonal [np .int_ | np . float64 ], / , diagonal : Sequence [ float ] ) -> None : ...
50
+ def __init__ (self : CovViaDiagonal [np .int_ ], / , diagonal : onp . ToJustInt64_1D ) -> None : ...
49
51
@overload
50
- def __init__ (self , / , diagonal : Sequence [ float | _SCT_co ] | onp .CanArrayND [ _SCT_co ]) -> None : ...
52
+ def __init__ (self , / , diagonal : onp .ToArray1D [ _ScalarT_co , _ScalarT_co ]) -> None : ...
51
53
52
54
class CovViaPrecision (Covariance [np .float64 ]):
53
55
def __init__ (self , / , precision : onp .ToFloat2D , covariance : onp .ToFloat2D | None = None ) -> None : ...
@@ -63,17 +65,17 @@ class _PSD(Protocol):
63
65
_M : onp .ArrayND [np .float64 ]
64
66
V : onp .ArrayND [np .float64 ]
65
67
U : onp .ArrayND [np .float64 ]
66
- eps : np . float64 | float
67
- log_pdet : np . float64 | float
68
- cond : np . float64 | float
68
+ eps : float
69
+ log_pdet : float
70
+ cond : float
69
71
rank : int
70
72
71
73
@property
72
74
def pinv (self , / ) -> onp .ArrayND [npc .floating ]: ...
73
75
74
76
class CovViaPSD (Covariance [np .float64 ]):
75
77
_LP : Final [onp .ArrayND [np .float64 ]]
76
- _log_pdet : Final [np . float64 | float ]
78
+ _log_pdet : Final [float ]
77
79
_rank : Final [int ]
78
80
_covariance : Final [onp .ArrayND [np .float64 ]]
79
81
_shape : tuple [int , int ]
0 commit comments