Skip to content

Commit b3b1d4e

Browse files
committed
special: improved logsumexp annotations
1 parent abc6a67 commit b3b1d4e

File tree

2 files changed

+272
-33
lines changed

2 files changed

+272
-33
lines changed

scipy-stubs/special/_logsumexp.pyi

Lines changed: 199 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,237 @@
1-
from typing import overload
1+
from typing import Any, TypeVar, overload
22

33
import numpy as np
44
import optype.numpy as onp
5+
import optype.numpy.compat as npc
56

67
from scipy._typing import AnyShape, Falsy, Truthy
78

89
__all__ = ["log_softmax", "logsumexp", "softmax"]
910

10-
@overload
11+
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
12+
_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
13+
_CFloatingT = TypeVar("_CFloatingT", bound=npc.complexfloating)
14+
15+
###
16+
17+
@overload # 0d/nd T, axis=None (default), keepdims=False (default)
18+
def logsumexp(
19+
a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
20+
axis: None = None,
21+
b: onp.ToComplex | onp.ToComplexND | None = None,
22+
keepdims: Falsy = False,
23+
return_sign: Falsy = False,
24+
) -> _InexactT: ...
25+
@overload # 0d/nd +float , axis=None (default), keepdims=False (default)
1126
def logsumexp(
12-
a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, return_sign: Falsy = False
27+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
28+
axis: None = None,
29+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
30+
keepdims: Falsy = False,
31+
return_sign: Falsy = False,
1332
) -> np.float64: ...
14-
@overload
33+
@overload # 0d/nd ~complex, axis=None (default), keepdims=False (default)
1534
def logsumexp(
16-
a: onp.ToComplex,
35+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
36+
axis: None = None,
37+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
38+
keepdims: Falsy = False,
39+
return_sign: Falsy = False,
40+
) -> np.complex128: ...
41+
@overload # 0d/nd T, keepdims=True
42+
def logsumexp(
43+
a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
1744
axis: AnyShape | None = None,
18-
b: onp.ToFloat | None = None,
19-
keepdims: bool = False,
45+
b: onp.ToComplex | onp.ToComplexND | None = None,
46+
*,
47+
keepdims: Truthy,
2048
return_sign: Falsy = False,
21-
) -> np.float64 | np.complex128: ...
22-
@overload
49+
) -> onp.ArrayND[_InexactT]: ...
50+
@overload # 0d/nd +float, keepdims=True
51+
def logsumexp(
52+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
53+
axis: AnyShape | None = None,
54+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
55+
*,
56+
keepdims: Truthy,
57+
return_sign: Falsy = False,
58+
) -> onp.ArrayND[np.float64]: ...
59+
@overload # 0d/nd ~complex, keepdims=True
60+
def logsumexp(
61+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
62+
axis: AnyShape | None = None,
63+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
64+
*,
65+
keepdims: Truthy,
66+
return_sign: Falsy = False,
67+
) -> onp.ArrayND[np.complex128]: ...
68+
@overload # 0d/nd T, axis=<given>
69+
def logsumexp(
70+
a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
71+
axis: AnyShape,
72+
b: onp.ToComplex | onp.ToComplexND | None = None,
73+
*,
74+
keepdims: Falsy = False,
75+
return_sign: Falsy = False,
76+
) -> onp.ArrayND[_InexactT] | Any: ...
77+
@overload # 0d/nd +float, axis=<given>
78+
def logsumexp(
79+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
80+
axis: AnyShape,
81+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
82+
keepdims: Falsy = False,
83+
return_sign: Falsy = False,
84+
) -> onp.ArrayND[np.float64] | Any: ...
85+
@overload # 0d/nd ~complex, axis=<given>
86+
def logsumexp(
87+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
88+
axis: AnyShape,
89+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
90+
keepdims: Falsy = False,
91+
return_sign: Falsy = False,
92+
) -> onp.ArrayND[np.complex128] | Any: ...
93+
@overload # floating fallback, return_sign=False
2394
def logsumexp(
24-
a: onp.ToFloatND,
95+
a: onp.ToFloat | onp.ToFloatND,
2596
axis: AnyShape | None = None,
2697
b: onp.ToFloat | onp.ToFloatND | None = None,
2798
keepdims: bool = False,
2899
return_sign: Falsy = False,
29-
) -> np.float64 | onp.ArrayND[np.float64]: ...
30-
@overload
100+
) -> onp.ArrayND[np.float64 | Any] | Any: ...
101+
@overload # ccomplex fallback, return_sign=False
31102
def logsumexp(
32-
a: onp.ToComplexND,
103+
a: onp.ToComplex | onp.ToComplexND,
33104
axis: AnyShape | None = None,
34-
b: onp.ToFloat | onp.ToFloatND | None = None,
105+
b: onp.ToComplex | onp.ToComplexND | None = None,
35106
keepdims: bool = False,
36107
return_sign: Falsy = False,
37-
) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
38-
@overload
108+
) -> onp.ArrayND[np.complex128 | Any] | Any: ...
109+
@overload # 0d/nd T@floating, axis=None (default), keepdims=False (default), return_sign=True
39110
def logsumexp(
40-
a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy
41-
) -> tuple[np.float64, bool | np.bool_]: ...
42-
@overload
111+
a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
112+
axis: None = None,
113+
b: onp.ToFloat | onp.ToFloatND | None = None,
114+
keepdims: Falsy = False,
115+
*,
116+
return_sign: Truthy,
117+
) -> tuple[_FloatingT, _FloatingT]: ...
118+
@overload # 0d/nd +float , axis=None (default), keepdims=False (default), return_sign=True
43119
def logsumexp(
44-
a: onp.ToComplex, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy
45-
) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ...
46-
@overload
120+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
121+
axis: None = None,
122+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
123+
keepdims: Falsy = False,
124+
*,
125+
return_sign: Truthy,
126+
) -> tuple[np.float64, np.float64]: ...
127+
@overload # 0d/nd ~complex, axis=None (default), keepdims=False (default), return_sign=True
128+
def logsumexp(
129+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
130+
axis: None = None,
131+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
132+
keepdims: Falsy = False,
133+
*,
134+
return_sign: Truthy,
135+
) -> tuple[np.float64, np.complex128]: ...
136+
@overload # 0d/nd T@complexfloating, axis=None (default), keepdims=False (default), return_sign=True
137+
def logsumexp(
138+
a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
139+
axis: None = None,
140+
b: onp.ToFloat | onp.ToFloatND | None = None,
141+
keepdims: Falsy = False,
142+
*,
143+
return_sign: Truthy,
144+
) -> tuple[npc.floating, _CFloatingT]: ...
145+
@overload # 0d/nd T@floatinv, keepdims=True, return_sign=True
47146
def logsumexp(
48-
a: onp.ToFloatND,
147+
a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
49148
axis: AnyShape | None = None,
50149
b: onp.ToFloat | onp.ToFloatND | None = None,
51-
keepdims: bool = False,
52150
*,
151+
keepdims: Truthy,
53152
return_sign: Truthy,
54-
) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ...
55-
@overload
153+
) -> tuple[onp.ArrayND[_FloatingT], onp.ArrayND[_FloatingT]]: ...
154+
@overload # 0d/nd +float, keepdims=True, return_sign=True
155+
def logsumexp(
156+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
157+
axis: AnyShape | None = None,
158+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
159+
*,
160+
keepdims: Truthy,
161+
return_sign: Truthy,
162+
) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.float64]]: ...
163+
@overload # 0d/nd ~complex, keepdims=True, return_sign=True
164+
def logsumexp(
165+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
166+
axis: AnyShape | None = None,
167+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
168+
*,
169+
keepdims: Truthy,
170+
return_sign: Truthy,
171+
) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.complex128]]: ...
172+
@overload # 0d/nd T@complexfloating, keepdims=True, return_sign=True
173+
def logsumexp(
174+
a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
175+
axis: AnyShape | None = None,
176+
b: onp.ToComplex | onp.ToComplexND | None = None,
177+
*,
178+
keepdims: Truthy,
179+
return_sign: Truthy,
180+
) -> tuple[onp.ArrayND[npc.floating], onp.ArrayND[_CFloatingT]]: ...
181+
@overload # 0d/nd T@floatinv, axis=<given>, return_sign=True
182+
def logsumexp(
183+
a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
184+
axis: AnyShape,
185+
b: onp.ToFloat | onp.ToFloatND | None = None,
186+
keepdims: Falsy = False,
187+
*,
188+
return_sign: Truthy,
189+
) -> tuple[onp.ArrayND[_FloatingT] | Any, onp.ArrayND[_FloatingT] | Any]: ...
190+
@overload # 0d/nd +float, axis=<given>, return_sign=True
56191
def logsumexp(
57-
a: onp.ToComplexND,
192+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
193+
axis: AnyShape,
194+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
195+
keepdims: Falsy = False,
196+
*,
197+
return_sign: Truthy,
198+
) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.float64] | Any]: ...
199+
@overload # 0d/nd ~complex, axis=<given>, return_sign=True
200+
def logsumexp(
201+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
202+
axis: AnyShape,
203+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
204+
keepdims: Falsy = False,
205+
*,
206+
return_sign: Truthy,
207+
) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.complex128] | Any]: ...
208+
@overload # 0d/nd T@complexfloating, axis=<given>, return_sign=True
209+
def logsumexp(
210+
a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
211+
axis: AnyShape,
212+
b: onp.ToComplex | onp.ToComplexND | None = None,
213+
keepdims: Falsy = False,
214+
*,
215+
return_sign: Truthy,
216+
) -> tuple[onp.ArrayND[npc.floating] | Any, onp.ArrayND[_CFloatingT] | Any]: ...
217+
@overload # floating fallback, return_sign=True
218+
def logsumexp(
219+
a: onp.ToFloat | onp.ToFloatND,
58220
axis: AnyShape | None = None,
59221
b: onp.ToFloat | onp.ToFloatND | None = None,
60222
keepdims: bool = False,
61223
*,
62224
return_sign: Truthy,
63-
) -> (
64-
tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]]
65-
): ...
225+
) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.float64 | Any] | Any]: ...
226+
@overload # ccomplex fallback, return_sign=True
227+
def logsumexp(
228+
a: onp.ToComplex | onp.ToComplexND,
229+
axis: AnyShape | None = None,
230+
b: onp.ToComplex | onp.ToComplexND | None = None,
231+
keepdims: bool = False,
232+
*,
233+
return_sign: Truthy,
234+
) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.complex128 | Any] | Any]: ...
66235

67236
#
68237
@overload

tests/special/test_logsumexp.pyi

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,77 @@
1+
from typing import Any, assert_type
2+
13
import numpy as np
4+
import numpy.typing as npt
25

36
from scipy.special import logsumexp
47

5-
# https://github.com/scipy/scipy-stubs/issues/697
6-
x = np.asarray([1, 2, 3], dtype=np.float64)
7-
logsumexp(x)
8+
py_f_0d: float
9+
py_c_0d: complex
10+
py_f_1d: list[float]
11+
py_c_1d: list[complex]
12+
13+
f16_0d: np.float16
14+
f16_1d: np.ndarray[tuple[int], np.dtype[np.float16]]
15+
c64_0d: np.complex64
16+
c64_1d: np.ndarray[tuple[int], np.dtype[np.complex64]]
17+
18+
###
19+
# logsumexp
20+
21+
assert_type(logsumexp(py_f_0d), np.float64)
22+
assert_type(logsumexp(py_f_1d), np.float64)
23+
assert_type(logsumexp(py_c_0d), np.complex128)
24+
assert_type(logsumexp(py_c_1d), np.complex128)
25+
assert_type(logsumexp(f16_0d), np.float16)
26+
assert_type(logsumexp(f16_1d), np.float16)
27+
assert_type(logsumexp(c64_0d), np.complex64)
28+
assert_type(logsumexp(c64_1d), np.complex64)
29+
30+
assert_type(logsumexp(py_f_0d, keepdims=True), npt.NDArray[np.float64])
31+
assert_type(logsumexp(py_f_1d, keepdims=True), npt.NDArray[np.float64])
32+
assert_type(logsumexp(py_c_0d, keepdims=True), npt.NDArray[np.complex128])
33+
assert_type(logsumexp(py_c_1d, keepdims=True), npt.NDArray[np.complex128])
34+
assert_type(logsumexp(f16_0d, keepdims=True), npt.NDArray[np.float16])
35+
assert_type(logsumexp(f16_1d, keepdims=True), npt.NDArray[np.float16])
36+
assert_type(logsumexp(c64_0d, keepdims=True), npt.NDArray[np.complex64])
37+
assert_type(logsumexp(c64_1d, keepdims=True), npt.NDArray[np.complex64])
38+
39+
assert_type(logsumexp(py_f_0d, axis=0), npt.NDArray[np.float64] | Any)
40+
assert_type(logsumexp(py_f_1d, axis=0), npt.NDArray[np.float64] | Any)
41+
assert_type(logsumexp(py_c_0d, axis=0), npt.NDArray[np.complex128] | Any)
42+
assert_type(logsumexp(py_c_1d, axis=0), npt.NDArray[np.complex128] | Any)
43+
assert_type(logsumexp(f16_0d, axis=0), npt.NDArray[np.float16] | Any)
44+
assert_type(logsumexp(f16_1d, axis=0), npt.NDArray[np.float16] | Any)
45+
assert_type(logsumexp(c64_0d, axis=0), npt.NDArray[np.complex64] | Any)
46+
assert_type(logsumexp(c64_1d, axis=0), npt.NDArray[np.complex64] | Any)
47+
48+
assert_type(logsumexp(py_f_0d, return_sign=True), tuple[np.float64, np.float64])
49+
assert_type(logsumexp(py_f_1d, return_sign=True), tuple[np.float64, np.float64])
50+
assert_type(logsumexp(py_c_0d, return_sign=True), tuple[np.float64, np.complex128])
51+
assert_type(logsumexp(py_c_1d, return_sign=True), tuple[np.float64, np.complex128])
52+
assert_type(logsumexp(f16_0d, return_sign=True), tuple[np.float16, np.float16])
53+
assert_type(logsumexp(f16_1d, return_sign=True), tuple[np.float16, np.float16])
54+
assert_type(logsumexp(c64_0d, return_sign=True), tuple[np.floating[Any], np.complex64])
55+
assert_type(logsumexp(c64_1d, return_sign=True), tuple[np.floating[Any], np.complex64])
56+
57+
assert_type(logsumexp(py_f_0d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]])
58+
assert_type(logsumexp(py_f_1d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]])
59+
assert_type(logsumexp(py_c_0d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.complex128]])
60+
assert_type(logsumexp(py_c_1d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.float64], npt.NDArray[np.complex128]])
61+
assert_type(logsumexp(f16_0d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.float16], npt.NDArray[np.float16]])
62+
assert_type(logsumexp(f16_1d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.float16], npt.NDArray[np.float16]])
63+
assert_type(logsumexp(c64_0d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.floating[Any]], npt.NDArray[np.complex64]])
64+
assert_type(logsumexp(c64_1d, keepdims=True, return_sign=True), tuple[npt.NDArray[np.floating[Any]], npt.NDArray[np.complex64]])
65+
66+
assert_type(logsumexp(py_f_0d, axis=0, return_sign=True), tuple[npt.NDArray[np.float64] | Any, npt.NDArray[np.float64] | Any])
67+
assert_type(logsumexp(py_f_1d, axis=0, return_sign=True), tuple[npt.NDArray[np.float64] | Any, npt.NDArray[np.float64] | Any])
68+
assert_type(logsumexp(py_c_0d, axis=0, return_sign=True), tuple[npt.NDArray[np.float64] | Any, npt.NDArray[np.complex128] | Any])
69+
assert_type(logsumexp(py_c_1d, axis=0, return_sign=True), tuple[npt.NDArray[np.float64] | Any, npt.NDArray[np.complex128] | Any])
70+
assert_type(logsumexp(f16_0d, axis=0, return_sign=True), tuple[npt.NDArray[np.float16] | Any, npt.NDArray[np.float16] | Any])
71+
assert_type(logsumexp(f16_1d, axis=0, return_sign=True), tuple[npt.NDArray[np.float16] | Any, npt.NDArray[np.float16] | Any])
72+
assert_type(
73+
logsumexp(c64_0d, axis=0, return_sign=True), tuple[npt.NDArray[np.floating[Any]] | Any, npt.NDArray[np.complex64] | Any]
74+
)
75+
assert_type(
76+
logsumexp(c64_1d, axis=0, return_sign=True), tuple[npt.NDArray[np.floating[Any]] | Any, npt.NDArray[np.complex64] | Any]
77+
)

0 commit comments

Comments
 (0)