1
1
from collections .abc import Mapping
2
- from typing import Any , Literal , Protocol , TypeAlias , TypeVar , overload , type_check_only
3
- from typing_extensions import deprecated
2
+ from typing import Any , Literal , Protocol , TypeAlias , overload , type_check_only
3
+ from typing_extensions import TypeVar , deprecated
4
4
5
5
import numpy as np
6
6
import optype .numpy as onp
@@ -24,12 +24,10 @@ __all__ = [
24
24
]
25
25
26
26
_SparseT = TypeVar ("_SparseT" , bound = _spbase )
27
+ _NumberT_contra = TypeVar ("_NumberT_contra" , bound = npc .number , contravariant = True )
28
+ _InexactT_co = TypeVar ("_InexactT_co" , bound = np .float32 | np .float64 | np .complex64 | np .complex128 , covariant = True )
27
29
28
30
_PermcSpec : TypeAlias = Literal ["COLAMD" , "NATURAL" , "MMD_ATA" , "MMD_AT_PLUS_A" ]
29
- _Float1D : TypeAlias = onp .Array1D [np .float64 ]
30
- _Float2D : TypeAlias = onp .Array2D [np .float64 ]
31
- _Complex1D : TypeAlias = onp .Array1D [np .complex128 ]
32
- _Complex2D : TypeAlias = onp .Array2D [np .complex128 ]
33
31
34
32
_ToF32Mat : TypeAlias = _spbase [np .float32 , tuple [int , int ]] | onp .CanArray [tuple [Any , ...], np .dtype [np .float32 ]]
35
33
_ToF64Mat : TypeAlias = _spbase [np .float64 | npc .integer , tuple [int , int ]] | onp .ToInt2D | onp .ToJustFloat64_2D
@@ -42,31 +40,35 @@ _ToComplexMat: TypeAlias = _spbase[npc.complexfloating, tuple[int, int]] | onp.T
42
40
_ToInexactMat : TypeAlias = _spbase [Any , tuple [int , int ]] | onp .ToComplex128_2D
43
41
_ToInexactMatStrict : TypeAlias = _spbase [Any , tuple [int , int ]] | onp .ToComplex128Strict2D
44
42
45
- # TODO(jorenham): make generic (because safe casting rules apply, i.e. the current annotations are incorrect)
46
- # https://github.com/scipy/scipy-stubs/issues/677
43
+ _AsF32 : TypeAlias = npc .integer8 | npc .number16 | np .int32 | np .float16 | np .float32
44
+ _AsF64 : TypeAlias = npc .integer | np .float16 | np .float32 | np .float64
45
+ _AsC64 : TypeAlias = npc .integer8 | npc .integer16 | np .int32 | np .float16 | npc .inexact32
46
+ _AsC128 : TypeAlias = npc .integer | np .float16 | npc .inexact32 | npc .inexact64
47
+
47
48
@type_check_only
48
- class _Solve (Protocol ):
49
- @overload
50
- def __call__ (self , b : onp .Array1D [npc .integer | npc .floating ], / ) -> _Float1D : ...
51
- @overload
52
- def __call__ (self , b : onp .Array1D [npc .complexfloating ], / ) -> _Complex1D : ...
49
+ class _SuperLU_solve (Protocol [_NumberT_contra , _InexactT_co ]):
53
50
@overload
54
- def __call__ (self , b : onp .Array2D [ npc . integer | npc . floating ], / ) -> _Float2D : ...
51
+ def __call__ (self , rhs : onp .Array1D [ _NumberT_contra | np . bool_ ] ) -> onp . Array1D [ _InexactT_co ] : ...
55
52
@overload
56
- def __call__ (self , b : onp .Array2D [npc . complexfloating ], / ) -> _Complex2D : ...
53
+ def __call__ (self , rhs : onp .Array2D [_NumberT_contra | np . bool_ ] ) -> onp . Array2D [ _InexactT_co ] : ...
57
54
@overload
58
- def __call__ (self , b : onp .ArrayND [npc .integer | npc .floating ], / ) -> _Float1D | _Float2D : ...
59
- @overload
60
- def __call__ (self , b : onp .ArrayND [npc .complexfloating ], / ) -> _Complex1D | _Complex2D : ...
61
- @overload
62
- def __call__ (self , b : onp .ArrayND [npc .number ], / ) -> _Float1D | _Complex1D | _Float2D | _Complex2D : ...
55
+ def __call__ (self , rhs : onp .ArrayND [_NumberT_contra | np .bool_ ]) -> onp .ArrayND [_InexactT_co ]: ...
63
56
64
57
###
65
58
66
59
class MatrixRankWarning (UserWarning ): ...
67
60
68
- def use_solver (* , useUmfpack : bool = ..., assumeSortedIndices : bool = ...) -> None : ...
69
- def factorized (A : _ToInexactMat ) -> _Solve : ...
61
+ def use_solver (* , useUmfpack : bool = True , assumeSortedIndices : bool = False ) -> None : ...
62
+
63
+ # NOTE: the mypy ignores work around a mypy bug in overload overlap checking
64
+ @overload
65
+ def factorized (A : _ToF64Mat ) -> _SuperLU_solve [_AsF64 , np .float64 ]: ... # type: ignore[overload-overlap]
66
+ @overload
67
+ def factorized (A : _ToC128Mat ) -> _SuperLU_solve [_AsC128 , np .complex128 ]: ... # type: ignore[overload-overlap]
68
+ @overload
69
+ def factorized (A : _ToF32Mat ) -> _SuperLU_solve [_AsF32 , np .float32 ]: ...
70
+ @overload
71
+ def factorized (A : _ToC64Mat ) -> _SuperLU_solve [_AsC64 , np .complex64 ]: ...
70
72
71
73
#
72
74
@overload # 2d float, sparse 2d
0 commit comments