Skip to content

Commit ec34cb1

Browse files
authored
✨ improve linalg.norm shape- & scalar-type overloads (#247)
1 parent dbed9be commit ec34cb1

File tree

1 file changed

+131
-17
lines changed

1 file changed

+131
-17
lines changed

scipy-stubs/linalg/_misc.pyi

Lines changed: 131 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,152 @@
1-
from typing import Literal, overload
1+
from typing import Any, Literal, TypeAlias, TypeVar, overload
22

33
import numpy as np
44
import numpy.typing as npt
5+
import optype as op
56
import optype.numpy as onp
6-
import optype.typing as opt
77
from numpy.linalg import LinAlgError
88
from scipy._typing import AnyBool
99

1010
__all__ = ["LinAlgError", "LinAlgWarning", "norm"]
1111

12+
_Inf: TypeAlias = float
13+
_Order: TypeAlias = Literal["fro", "nuc", 0, 1, -1, 2, -2] | _Inf
14+
_Axis: TypeAlias = op.CanIndex | tuple[op.CanIndex, op.CanIndex]
15+
16+
_Falsy: TypeAlias = Literal[False, 0]
17+
_Truthy: TypeAlias = Literal[True, 1]
18+
19+
_SubScalar: TypeAlias = np.complex128 | np.float64 | np.integer[Any] | np.bool_
20+
21+
_NBitT = TypeVar("_NBitT", bound=npt.NBitBase)
22+
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
23+
24+
###
25+
1226
class LinAlgWarning(RuntimeWarning): ...
1327

14-
@overload
28+
@overload # scalar, axis: None = ...
1529
def norm(
16-
a: npt.ArrayLike,
17-
ord: Literal["fro", "nuc", 0, 1, -1, 2, -2] | float | None = None,
30+
a: complex | _SubScalar,
31+
ord: _Order | None = None,
1832
axis: None = None,
19-
keepdims: AnyBool = False,
33+
keepdims: op.CanBool = False,
2034
check_finite: AnyBool = True,
2135
) -> np.float64: ...
22-
@overload
36+
@overload # inexact, axis: None = ...
2337
def norm(
24-
a: npt.ArrayLike,
25-
ord: Literal["fro", "nuc", 0, 1, -1, 2, -2] | float | None,
26-
axis: opt.AnyInt | tuple[opt.AnyInt, ...],
27-
keepdims: AnyBool = False,
38+
a: np.inexact[_NBitT],
39+
ord: _Order | None = None,
40+
axis: None = None,
41+
keepdims: op.CanBool = False,
2842
check_finite: AnyBool = True,
29-
) -> np.float64 | onp.ArrayND[np.float64]: ...
30-
@overload
43+
) -> np.floating[_NBitT]: ...
44+
@overload # scalar array, axis: None = ..., keepdims: False = ...
3145
def norm(
32-
a: npt.ArrayLike,
33-
ord: Literal["fro", "nuc", 0, 1, -1, 2, -2] | float | None = None,
46+
a: onp.CanArrayND[_SubScalar] | onp.SequenceND[onp.CanArrayND[_SubScalar]] | onp.SequenceND[_SubScalar],
47+
ord: _Order | None = None,
48+
axis: None = None,
49+
keepdims: _Falsy = False,
50+
check_finite: AnyBool = True,
51+
) -> np.float64: ...
52+
@overload # float64-coercible array, keepdims: True (positional)
53+
def norm(
54+
a: onp.CanArrayND[_SubScalar, _ShapeT],
55+
ord: _Order | None,
56+
axis: _Axis | None,
57+
keepdims: _Truthy,
58+
check_finite: AnyBool = True,
59+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
60+
@overload # float64-coercible array, keepdims: True (keyword)
61+
def norm(
62+
a: onp.CanArrayND[_SubScalar, _ShapeT],
63+
ord: _Order | None = None,
64+
axis: _Axis | None = None,
65+
*,
66+
keepdims: _Truthy,
67+
check_finite: AnyBool = True,
68+
) -> onp.ArrayND[np.float64, _ShapeT]: ...
69+
@overload # float64-coercible array-like, keepdims: True (positional)
70+
def norm(
71+
a: onp.SequenceND[onp.CanArrayND[_SubScalar]] | onp.SequenceND[complex | _SubScalar],
72+
ord: _Order | None,
73+
axis: _Axis | None,
74+
keepdims: _Truthy,
75+
check_finite: AnyBool = True,
76+
) -> onp.ArrayND[np.float64]: ...
77+
@overload # float64-coercible array-like, keepdims: True (keyword)
78+
def norm(
79+
a: onp.SequenceND[onp.CanArrayND[_SubScalar]] | onp.SequenceND[complex | _SubScalar],
80+
ord: _Order | None = None,
81+
axis: _Axis | None = None,
3482
*,
35-
axis: opt.AnyInt | tuple[opt.AnyInt, ...],
83+
keepdims: _Truthy,
84+
check_finite: AnyBool = True,
85+
) -> onp.ArrayND[np.float64]: ...
86+
@overload # shaped inexact array, keepdims: True (positional)
87+
def norm(
88+
a: onp.CanArrayND[np.inexact[_NBitT], _ShapeT],
89+
ord: _Order | None,
90+
axis: _Axis | None,
91+
keepdims: _Truthy,
92+
check_finite: AnyBool = True,
93+
) -> onp.ArrayND[np.floating[_NBitT], _ShapeT]: ...
94+
@overload # shaped inexact array, keepdims: True (keyword)
95+
def norm(
96+
a: onp.CanArrayND[np.inexact[_NBitT], _ShapeT],
97+
ord: _Order | None = None,
98+
axis: _Axis | None = None,
99+
*,
100+
keepdims: _Truthy,
101+
check_finite: AnyBool = True,
102+
) -> onp.ArrayND[np.floating[_NBitT], _ShapeT]: ...
103+
@overload # scalar array-like, keepdims: True (positional)
104+
def norm(
105+
a: onp.SequenceND[onp.CanArrayND[np.inexact[_NBitT]]] | onp.SequenceND[np.inexact[_NBitT]],
106+
ord: _Order | None,
107+
axis: _Axis | None,
108+
keepdims: _Truthy,
109+
check_finite: AnyBool = True,
110+
) -> onp.ArrayND[np.floating[_NBitT]]: ...
111+
@overload # scalar array-like, keepdims: True (keyword)
112+
def norm(
113+
a: onp.SequenceND[onp.CanArrayND[np.inexact[_NBitT]]] | onp.SequenceND[np.inexact[_NBitT]],
114+
ord: _Order | None = None,
115+
axis: _Axis | None = None,
116+
*,
117+
keepdims: _Truthy,
118+
check_finite: AnyBool = True,
119+
) -> onp.ArrayND[np.floating[_NBitT]]: ...
120+
@overload # array-like, axis: None = ..., keepdims: False = ...
121+
def norm(
122+
a: onp.ToComplexND,
123+
ord: _Order | None = None,
124+
axis: None = None,
125+
keepdims: _Falsy = False,
126+
check_finite: AnyBool = True,
127+
) -> np.float64: ...
128+
@overload # array-like, keepdims: True (positional)
129+
def norm(
130+
a: onp.ToComplexND,
131+
ord: _Order | None,
132+
axis: _Axis | None,
133+
keepdims: _Truthy,
134+
check_finite: AnyBool = True,
135+
) -> onp.ArrayND[np.floating[Any]]: ...
136+
@overload # array-like, keepdims: True (keyword)
137+
def norm(
138+
a: onp.ToComplexND,
139+
ord: _Order | None = None,
140+
axis: _Axis | None = None,
141+
*,
142+
keepdims: _Truthy,
143+
check_finite: AnyBool = True,
144+
) -> onp.ArrayND[np.floating[Any]]: ...
145+
@overload # catch-all
146+
def norm(
147+
a: npt.ArrayLike,
148+
ord: _Order | None = None,
149+
axis: _Axis | None = None,
36150
keepdims: AnyBool = False,
37151
check_finite: AnyBool = True,
38-
) -> np.float64 | onp.ArrayND[np.float64]: ...
152+
) -> np.floating[Any] | onp.ArrayND[np.floating[Any]]: ...

0 commit comments

Comments
 (0)