Skip to content

Commit 9c447a3

Browse files
authored
special: improved logsumexp and [log_]softmax (#699)
2 parents abc6a67 + b721780 commit 9c447a3

File tree

2 files changed

+329
-52
lines changed

2 files changed

+329
-52
lines changed

scipy-stubs/special/_logsumexp.pyi

Lines changed: 242 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,279 @@
1-
from typing import overload
1+
from typing import Any, TypeVar, overload
22

33
import numpy as np
44
import optype.numpy as onp
5+
import optype.numpy.compat as npc
56

67
from scipy._typing import AnyShape, Falsy, Truthy
78

89
__all__ = ["log_softmax", "logsumexp", "softmax"]
910

10-
@overload
11+
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
12+
_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
13+
_CFloatingT = TypeVar("_CFloatingT", bound=npc.complexfloating)
14+
_InexactOrArrayT = TypeVar("_InexactOrArrayT", bound=npc.inexact | onp.ArrayND[npc.inexact])
15+
16+
###
17+
18+
@overload # 0d/nd T, axis=None (default), keepdims=False (default)
1119
def logsumexp(
12-
a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, return_sign: Falsy = False
20+
a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
21+
axis: None = None,
22+
b: onp.ToComplex | onp.ToComplexND | None = None,
23+
keepdims: Falsy = False,
24+
return_sign: Falsy = False,
25+
) -> _InexactT: ...
26+
@overload # 0d/nd +float , axis=None (default), keepdims=False (default)
27+
def logsumexp(
28+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
29+
axis: None = None,
30+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
31+
keepdims: Falsy = False,
32+
return_sign: Falsy = False,
1333
) -> np.float64: ...
14-
@overload
34+
@overload # 0d/nd ~complex, axis=None (default), keepdims=False (default)
1535
def logsumexp(
16-
a: onp.ToComplex,
36+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
37+
axis: None = None,
38+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
39+
keepdims: Falsy = False,
40+
return_sign: Falsy = False,
41+
) -> np.complex128: ...
42+
@overload # 0d/nd T, keepdims=True
43+
def logsumexp(
44+
a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
1745
axis: AnyShape | None = None,
18-
b: onp.ToFloat | None = None,
19-
keepdims: bool = False,
46+
b: onp.ToComplex | onp.ToComplexND | None = None,
47+
*,
48+
keepdims: Truthy,
49+
return_sign: Falsy = False,
50+
) -> onp.ArrayND[_InexactT]: ...
51+
@overload # 0d/nd +float, keepdims=True
52+
def logsumexp(
53+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
54+
axis: AnyShape | None = None,
55+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
56+
*,
57+
keepdims: Truthy,
58+
return_sign: Falsy = False,
59+
) -> onp.ArrayND[np.float64]: ...
60+
@overload # 0d/nd ~complex, keepdims=True
61+
def logsumexp(
62+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
63+
axis: AnyShape | None = None,
64+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
65+
*,
66+
keepdims: Truthy,
2067
return_sign: Falsy = False,
21-
) -> np.float64 | np.complex128: ...
22-
@overload
68+
) -> onp.ArrayND[np.complex128]: ...
69+
@overload # 0d/nd T, axis=<given>
2370
def logsumexp(
24-
a: onp.ToFloatND,
71+
a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
72+
axis: AnyShape,
73+
b: onp.ToComplex | onp.ToComplexND | None = None,
74+
*,
75+
keepdims: Falsy = False,
76+
return_sign: Falsy = False,
77+
) -> onp.ArrayND[_InexactT] | Any: ...
78+
@overload # 0d/nd +float, axis=<given>
79+
def logsumexp(
80+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
81+
axis: AnyShape,
82+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
83+
keepdims: Falsy = False,
84+
return_sign: Falsy = False,
85+
) -> onp.ArrayND[np.float64] | Any: ...
86+
@overload # 0d/nd ~complex, axis=<given>
87+
def logsumexp(
88+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
89+
axis: AnyShape,
90+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
91+
keepdims: Falsy = False,
92+
return_sign: Falsy = False,
93+
) -> onp.ArrayND[np.complex128] | Any: ...
94+
@overload # floating fallback, return_sign=False
95+
def logsumexp(
96+
a: onp.ToFloat | onp.ToFloatND,
2597
axis: AnyShape | None = None,
2698
b: onp.ToFloat | onp.ToFloatND | None = None,
2799
keepdims: bool = False,
28100
return_sign: Falsy = False,
29-
) -> np.float64 | onp.ArrayND[np.float64]: ...
30-
@overload
101+
) -> onp.ArrayND[np.float64 | Any] | Any: ...
102+
@overload # complex fallback, return_sign=False
31103
def logsumexp(
32-
a: onp.ToComplexND,
104+
a: onp.ToComplex | onp.ToComplexND,
33105
axis: AnyShape | None = None,
34-
b: onp.ToFloat | onp.ToFloatND | None = None,
106+
b: onp.ToComplex | onp.ToComplexND | None = None,
35107
keepdims: bool = False,
36108
return_sign: Falsy = False,
37-
) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
38-
@overload
109+
) -> onp.ArrayND[np.complex128 | Any] | Any: ...
110+
@overload # 0d/nd T@floating, axis=None (default), keepdims=False (default), return_sign=True
111+
def logsumexp(
112+
a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
113+
axis: None = None,
114+
b: onp.ToFloat | onp.ToFloatND | None = None,
115+
keepdims: Falsy = False,
116+
*,
117+
return_sign: Truthy,
118+
) -> tuple[_FloatingT, _FloatingT]: ...
119+
@overload # 0d/nd +float , axis=None (default), keepdims=False (default), return_sign=True
120+
def logsumexp(
121+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
122+
axis: None = None,
123+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
124+
keepdims: Falsy = False,
125+
*,
126+
return_sign: Truthy,
127+
) -> tuple[np.float64, np.float64]: ...
128+
@overload # 0d/nd ~complex, axis=None (default), keepdims=False (default), return_sign=True
129+
def logsumexp(
130+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
131+
axis: None = None,
132+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
133+
keepdims: Falsy = False,
134+
*,
135+
return_sign: Truthy,
136+
) -> tuple[np.float64, np.complex128]: ...
137+
@overload # 0d/nd T@complexfloating, axis=None (default), keepdims=False (default), return_sign=True
138+
def logsumexp(
139+
a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
140+
axis: None = None,
141+
b: onp.ToFloat | onp.ToFloatND | None = None,
142+
keepdims: Falsy = False,
143+
*,
144+
return_sign: Truthy,
145+
) -> tuple[npc.floating, _CFloatingT]: ...
146+
@overload # 0d/nd T@floatinv, keepdims=True, return_sign=True
147+
def logsumexp(
148+
a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
149+
axis: AnyShape | None = None,
150+
b: onp.ToFloat | onp.ToFloatND | None = None,
151+
*,
152+
keepdims: Truthy,
153+
return_sign: Truthy,
154+
) -> tuple[onp.ArrayND[_FloatingT], onp.ArrayND[_FloatingT]]: ...
155+
@overload # 0d/nd +float, keepdims=True, return_sign=True
156+
def logsumexp(
157+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
158+
axis: AnyShape | None = None,
159+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
160+
*,
161+
keepdims: Truthy,
162+
return_sign: Truthy,
163+
) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.float64]]: ...
164+
@overload # 0d/nd ~complex, keepdims=True, return_sign=True
165+
def logsumexp(
166+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
167+
axis: AnyShape | None = None,
168+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
169+
*,
170+
keepdims: Truthy,
171+
return_sign: Truthy,
172+
) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.complex128]]: ...
173+
@overload # 0d/nd T@complexfloating, keepdims=True, return_sign=True
39174
def logsumexp(
40-
a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy
41-
) -> tuple[np.float64, bool | np.bool_]: ...
42-
@overload
175+
a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
176+
axis: AnyShape | None = None,
177+
b: onp.ToComplex | onp.ToComplexND | None = None,
178+
*,
179+
keepdims: Truthy,
180+
return_sign: Truthy,
181+
) -> tuple[onp.ArrayND[npc.floating], onp.ArrayND[_CFloatingT]]: ...
182+
@overload # 0d/nd T@floatinv, axis=<given>, return_sign=True
43183
def logsumexp(
44-
a: onp.ToComplex, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy
45-
) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ...
46-
@overload
184+
a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
185+
axis: AnyShape,
186+
b: onp.ToFloat | onp.ToFloatND | None = None,
187+
keepdims: Falsy = False,
188+
*,
189+
return_sign: Truthy,
190+
) -> tuple[onp.ArrayND[_FloatingT] | Any, onp.ArrayND[_FloatingT] | Any]: ...
191+
@overload # 0d/nd +float, axis=<given>, return_sign=True
47192
def logsumexp(
48-
a: onp.ToFloatND,
193+
a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
194+
axis: AnyShape,
195+
b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
196+
keepdims: Falsy = False,
197+
*,
198+
return_sign: Truthy,
199+
) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.float64] | Any]: ...
200+
@overload # 0d/nd ~complex, axis=<given>, return_sign=True
201+
def logsumexp(
202+
a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
203+
axis: AnyShape,
204+
b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
205+
keepdims: Falsy = False,
206+
*,
207+
return_sign: Truthy,
208+
) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.complex128] | Any]: ...
209+
@overload # 0d/nd T@complexfloating, axis=<given>, return_sign=True
210+
def logsumexp(
211+
a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
212+
axis: AnyShape,
213+
b: onp.ToComplex | onp.ToComplexND | None = None,
214+
keepdims: Falsy = False,
215+
*,
216+
return_sign: Truthy,
217+
) -> tuple[onp.ArrayND[npc.floating] | Any, onp.ArrayND[_CFloatingT] | Any]: ...
218+
@overload # floating fallback, return_sign=True
219+
def logsumexp(
220+
a: onp.ToFloat | onp.ToFloatND,
49221
axis: AnyShape | None = None,
50222
b: onp.ToFloat | onp.ToFloatND | None = None,
51223
keepdims: bool = False,
52224
*,
53225
return_sign: Truthy,
54-
) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ...
55-
@overload
226+
) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.float64 | Any] | Any]: ...
227+
@overload # complex fallback, return_sign=True
56228
def logsumexp(
57-
a: onp.ToComplexND,
229+
a: onp.ToComplex | onp.ToComplexND,
58230
axis: AnyShape | None = None,
59-
b: onp.ToFloat | onp.ToFloatND | None = None,
231+
b: onp.ToComplex | onp.ToComplexND | None = None,
60232
keepdims: bool = False,
61233
*,
62234
return_sign: Truthy,
63-
) -> (
64-
tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]]
65-
): ...
235+
) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.complex128 | Any] | Any]: ...
66236

67-
#
68-
@overload
69-
def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
70-
@overload
71-
def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
72-
@overload
73-
def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
74-
@overload
75-
def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
237+
# NOTE: keep in sync with `log_softmax`
238+
@overload # T
239+
def softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ... # type: ignore[overload-overlap]
240+
@overload # 0d +float64
241+
def softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
242+
@overload # 0d ~complex128
243+
def softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
244+
@overload # nd T@inexact
245+
def softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
246+
@overload # nd +float64
247+
def softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
248+
@overload # nd ~complex128
249+
def softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
250+
@overload # 0d float fallback
251+
def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
252+
@overload # 0d complex fallback
253+
def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
254+
@overload # nd float fallback
255+
def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
256+
@overload # nd complex fallback
257+
def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...
76258

77-
#
78-
@overload
79-
def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
80-
@overload
81-
def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
82-
@overload
83-
def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
84-
@overload
85-
def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
259+
# NOTE: keep in sync with `softmax`
260+
@overload # T
261+
def log_softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ... # type: ignore[overload-overlap]
262+
@overload # 0d +float64
263+
def log_softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
264+
@overload # 0d ~complex128
265+
def log_softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
266+
@overload # nd T@inexact
267+
def log_softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
268+
@overload # nd +float64
269+
def log_softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
270+
@overload # nd ~complex128
271+
def log_softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
272+
@overload # 0d float fallback
273+
def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
274+
@overload # 0d complex fallback
275+
def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
276+
@overload # nd float fallback
277+
def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
278+
@overload # nd complex fallback
279+
def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...

0 commit comments

Comments
 (0)