1
1
# mypy: disable-error-code=overload-overlap
2
2
3
+ from collections .abc import Sequence
3
4
from typing import Final , Literal , TypeAlias , TypeVar , overload
4
5
5
6
import numpy as np
@@ -23,20 +24,20 @@ __all__ = [
23
24
"solveh_banded" ,
24
25
]
25
26
27
+ _ShapeT = TypeVar ("_ShapeT" , bound = tuple [int , ...])
26
28
_T = TypeVar ("_T" )
29
+
27
30
_Tuple2 : TypeAlias = tuple [_T , _T ]
28
31
_COrCR : TypeAlias = _T | _Tuple2 [_T ]
29
32
30
33
_Float : TypeAlias = npc .floating
31
34
_Float0D : TypeAlias = onp .Array0D [_Float ]
32
35
_Float1D : TypeAlias = onp .Array1D [_Float ]
33
- _Float2D : TypeAlias = onp .Array2D [_Float ]
34
36
_FloatND : TypeAlias = onp .ArrayND [_Float ]
35
37
36
38
_Inexact : TypeAlias = npc .inexact
37
39
_Inexact0D : TypeAlias = onp .Array0D [_Inexact ]
38
40
_Inexact1D : TypeAlias = onp .Array1D [_Inexact ]
39
- _Inexact2D : TypeAlias = onp .Array2D [_Inexact ]
40
41
_InexactND : TypeAlias = onp .ArrayND [_Inexact ]
41
42
42
43
_InputFloat : TypeAlias = onp .ToArrayND [float , np .float64 | np .longdouble | npc .integer | np .bool_ ]
@@ -992,15 +993,48 @@ def solve_circulant(
992
993
outaxis : int = 0 ,
993
994
) -> onp .ArrayND [npc .inexact ]: ...
994
995
995
- # TODO(jorenham): improve this
996
- @overload # floating 2d
997
- def inv (a : onp .ToFloatStrict2D , overwrite_a : bool = False , check_finite : bool = True ) -> _Float2D : ...
998
- @overload # floating
999
- def inv (a : onp .ToFloatND , overwrite_a : bool = False , check_finite : bool = True ) -> _FloatND : ...
1000
- @overload # complexfloating 2d
1001
- def inv (a : onp .ToComplexStrict2D , overwrite_a : bool = False , check_finite : bool = True ) -> _Inexact2D : ...
1002
- @overload # complexfloating
1003
- def inv (a : onp .ToComplexND , overwrite_a : bool = False , check_finite : bool = True ) -> _InexactND : ...
996
+ #
997
+
998
+ @overload # 2d bool sequence
999
+ def inv (a : Sequence [Sequence [bool ]], overwrite_a : bool = False , check_finite : bool = True ) -> onp .Array2D [np .float32 ]: ...
1000
+ @overload # Nd bool sequence
1001
+ def inv (a : Sequence [onp .SequenceND [bool ]], overwrite_a : bool = False , check_finite : bool = True ) -> onp .ArrayND [np .float32 ]: ...
1002
+ @overload # 2d float or int sequence
1003
+ def inv (
1004
+ a : Sequence [Sequence [op .JustFloat | op .JustInt ]], overwrite_a : bool = False , check_finite : bool = True
1005
+ ) -> onp .Array2D [np .float64 ]: ...
1006
+ @overload # Nd float or int sequence
1007
+ def inv (
1008
+ a : Sequence [onp .SequenceND [op .JustFloat | op .JustInt ]], overwrite_a : bool = False , check_finite : bool = True
1009
+ ) -> onp .ArrayND [np .float64 ]: ...
1010
+ @overload # 2d complex sequence
1011
+ def inv (
1012
+ a : Sequence [Sequence [op .JustComplex ]], overwrite_a : bool = False , check_finite : bool = True
1013
+ ) -> onp .Array2D [np .complex128 ]: ...
1014
+ @overload # Nd complex sequence
1015
+ def inv (
1016
+ a : Sequence [onp .SequenceND [op .JustComplex ]], overwrite_a : bool = False , check_finite : bool = True
1017
+ ) -> onp .ArrayND [np .complex128 ]: ...
1018
+ @overload # generic shape, as float32
1019
+ def inv (
1020
+ a : onp .CanArrayND [np .float32 | npc .number16 | npc .integer8 | np .bool_ , _ShapeT ],
1021
+ overwrite_a : bool = False ,
1022
+ check_finite : bool = True ,
1023
+ ) -> onp .ArrayND [np .float32 , _ShapeT ]: ...
1024
+ @overload # generic shape, as float64
1025
+ def inv (
1026
+ a : onp .CanArrayND [np .float64 | np .longdouble | npc .integer64 | npc .integer32 , _ShapeT ],
1027
+ overwrite_a : bool = False ,
1028
+ check_finite : bool = True ,
1029
+ ) -> onp .ArrayND [np .float64 , _ShapeT ]: ...
1030
+ @overload # generic shape, as complex64
1031
+ def inv (
1032
+ a : onp .CanArrayND [np .complex64 , _ShapeT ], overwrite_a : bool = False , check_finite : bool = True
1033
+ ) -> onp .ArrayND [np .complex64 , _ShapeT ]: ...
1034
+ @overload # generic shape, as complex128
1035
+ def inv (
1036
+ a : onp .CanArrayND [np .complex128 | np .clongdouble , _ShapeT ], overwrite_a : bool = False , check_finite : bool = True
1037
+ ) -> onp .ArrayND [np .complex128 , _ShapeT ]: ...
1004
1038
1005
1039
# TODO(jorenham): improve this
1006
1040
@overload # floating 2d
0 commit comments