Skip to content

Commit a9a79f4

Browse files
committed
special: improve logsumexp, softmax and log_softmax
1 parent 7c51738 commit a9a79f4

File tree

1 file changed

+67
-19
lines changed

1 file changed

+67
-19
lines changed

scipy-stubs/special/_logsumexp.pyi

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,98 @@
1-
from typing import Literal, overload
1+
from typing import overload
22

33
import numpy as np
44
import optype.numpy as onp
5+
from scipy._typing import AnyShape, Falsy, Truthy
56

67
__all__ = ["log_softmax", "logsumexp", "softmax"]
78

8-
# TODO: Support `return_sign=True`
99
@overload
1010
def logsumexp(
1111
a: onp.ToFloat,
12-
axis: int | tuple[int, ...] | None = None,
12+
axis: AnyShape | None = None,
1313
b: onp.ToFloat | None = None,
1414
keepdims: bool = False,
15-
return_sign: Literal[False, 0] = False,
15+
return_sign: Falsy = False,
1616
) -> np.float64: ...
1717
@overload
1818
def logsumexp(
1919
a: onp.ToComplex,
20-
axis: int | tuple[int, ...] | None = None,
20+
axis: AnyShape | None = None,
2121
b: onp.ToFloat | None = None,
2222
keepdims: bool = False,
23-
return_sign: Literal[False, 0] = False,
23+
return_sign: Falsy = False,
2424
) -> np.float64 | np.complex128: ...
2525
@overload
2626
def logsumexp(
2727
a: onp.ToFloatND,
28-
axis: int | tuple[int, ...],
28+
axis: AnyShape,
2929
b: onp.ToFloat | onp.ToFloatND | None = None,
3030
keepdims: bool = False,
31-
return_sign: Literal[False, 0] = False,
31+
return_sign: Falsy = False,
3232
) -> np.float64 | onp.ArrayND[np.float64]: ...
3333
@overload
3434
def logsumexp(
3535
a: onp.ToComplexND,
36-
axis: int | tuple[int, ...],
36+
axis: AnyShape,
3737
b: onp.ToFloat | onp.ToFloatND | None = None,
3838
keepdims: bool = False,
39-
return_sign: Literal[False, 0] = False,
39+
return_sign: Falsy = False,
4040
) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
41+
@overload
42+
def logsumexp(
43+
a: onp.ToFloat,
44+
axis: AnyShape | None = None,
45+
b: onp.ToFloat | None = None,
46+
keepdims: bool = False,
47+
*,
48+
return_sign: Truthy,
49+
) -> tuple[np.float64, bool | np.bool_]: ...
50+
@overload
51+
def logsumexp(
52+
a: onp.ToComplex,
53+
axis: AnyShape | None = None,
54+
b: onp.ToFloat | None = None,
55+
keepdims: bool = False,
56+
*,
57+
return_sign: Truthy,
58+
) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ...
59+
@overload
60+
def logsumexp(
61+
a: onp.ToFloatND,
62+
axis: AnyShape,
63+
b: onp.ToFloat | onp.ToFloatND | None = None,
64+
keepdims: bool = False,
65+
*,
66+
return_sign: Truthy,
67+
) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ...
68+
@overload
69+
def logsumexp(
70+
a: onp.ToComplexND,
71+
axis: AnyShape,
72+
b: onp.ToFloat | onp.ToFloatND | None = None,
73+
keepdims: bool = False,
74+
*,
75+
return_sign: Truthy,
76+
) -> (
77+
tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]]
78+
): ...
4179

42-
# TODO: Overload real/complex and scalar/array
43-
def softmax(
44-
x: onp.ToComplex | onp.ToComplexND,
45-
axis: int | tuple[int, ...] | None = None,
46-
) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
47-
def log_softmax(
48-
x: onp.ToComplex | onp.ToComplexND,
49-
axis: int | tuple[int, ...] | None = None,
50-
) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
80+
#
81+
@overload
82+
def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
83+
@overload
84+
def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
85+
@overload
86+
def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
87+
@overload
88+
def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
89+
90+
#
91+
@overload
92+
def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
93+
@overload
94+
def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
95+
@overload
96+
def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
97+
@overload
98+
def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...

0 commit comments

Comments
 (0)