Skip to content

Commit b721780

Browse files
committed
special: improved [log_]softmax annotations
1 parent b3b1d4e commit b721780

File tree

2 files changed: +59 additions, −21 deletions

scipy-stubs/special/_logsumexp.pyi

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ __all__ = ["log_softmax", "logsumexp", "softmax"]
1111
_InexactT = TypeVar("_InexactT", bound=npc.inexact)
1212
_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
1313
_CFloatingT = TypeVar("_CFloatingT", bound=npc.complexfloating)
14+
_InexactOrArrayT = TypeVar("_InexactOrArrayT", bound=npc.inexact | onp.ArrayND[npc.inexact])
1415

1516
###
1617

@@ -98,7 +99,7 @@ def logsumexp(
9899
keepdims: bool = False,
99100
return_sign: Falsy = False,
100101
) -> onp.ArrayND[np.float64 | Any] | Any: ...
101-
@overload # ccomplex fallback, return_sign=False
102+
@overload # complex fallback, return_sign=False
102103
def logsumexp(
103104
a: onp.ToComplex | onp.ToComplexND,
104105
axis: AnyShape | None = None,
@@ -223,7 +224,7 @@ def logsumexp(
223224
*,
224225
return_sign: Truthy,
225226
) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.float64 | Any] | Any]: ...
226-
@overload # ccomplex fallback, return_sign=True
227+
@overload # complex fallback, return_sign=True
227228
def logsumexp(
228229
a: onp.ToComplex | onp.ToComplexND,
229230
axis: AnyShape | None = None,
@@ -233,22 +234,46 @@ def logsumexp(
233234
return_sign: Truthy,
234235
) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.complex128 | Any] | Any]: ...
235236

236-
#
237-
@overload
238-
def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
239-
@overload
240-
def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
241-
@overload
242-
def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
243-
@overload
244-
def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
237+
# NOTE: keep in sync with `log_softmax`
238+
@overload # T
239+
def softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ... # type: ignore[overload-overlap]
240+
@overload # 0d +float64
241+
def softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
242+
@overload # 0d ~complex128
243+
def softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
244+
@overload # nd T@inexact
245+
def softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
246+
@overload # nd +float64
247+
def softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
248+
@overload # nd ~complex128
249+
def softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
250+
@overload # 0d float fallback
251+
def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
252+
@overload # 0d complex fallback
253+
def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
254+
@overload # nd float fallback
255+
def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
256+
@overload # nd complex fallback
257+
def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...
245258

246-
#
247-
@overload
248-
def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
249-
@overload
250-
def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
251-
@overload
252-
def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
253-
@overload
254-
def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
259+
# NOTE: keep in sync with `softmax`
260+
@overload # T
261+
def log_softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ... # type: ignore[overload-overlap]
262+
@overload # 0d +float64
263+
def log_softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
264+
@overload # 0d ~complex128
265+
def log_softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
266+
@overload # nd T@inexact
267+
def log_softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
268+
@overload # nd +float64
269+
def log_softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
270+
@overload # nd ~complex128
271+
def log_softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
272+
@overload # 0d float fallback
273+
def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
274+
@overload # 0d complex fallback
275+
def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
276+
@overload # nd float fallback
277+
def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
278+
@overload # nd complex fallback
279+
def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...

tests/special/test_logsumexp.pyi

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from typing import Any, assert_type
33
import numpy as np
44
import numpy.typing as npt
55

6-
from scipy.special import logsumexp
6+
from scipy.special import logsumexp, softmax
77

88
py_f_0d: float
99
py_c_0d: complex
@@ -75,3 +75,16 @@ assert_type(
7575
assert_type(
7676
logsumexp(c64_1d, axis=0, return_sign=True), tuple[npt.NDArray[np.floating[Any]] | Any, npt.NDArray[np.complex64] | Any]
7777
)
78+
79+
###
80+
# softmax (equiv log_softmax)
81+
82+
assert_type(softmax(py_f_0d), np.float64)
83+
assert_type(softmax(py_c_0d), np.complex128)
84+
assert_type(softmax(f16_0d), np.float16)
85+
assert_type(softmax(c64_0d), np.complex64)
86+
87+
assert_type(softmax(py_f_1d), npt.NDArray[np.float64])
88+
assert_type(softmax(py_c_1d), npt.NDArray[np.complex128])
89+
assert_type(softmax(f16_1d), np.ndarray[tuple[int], np.dtype[np.float16]])
90+
assert_type(softmax(c64_1d), np.ndarray[tuple[int], np.dtype[np.complex64]])

Comments (0) — no commit comments