from typing import overload

import numpy as np

import optype.numpy as onp
from scipy._typing import AnyShape, Falsy, Truthy

__all__ = ["log_softmax", "logsumexp", "softmax"]

# `logsumexp` is overloaded along two independent axes:
#   * input kind: real (`ToFloat*`) vs. complex (`ToComplex*`), scalar vs. N-d array-like
#   * `return_sign`: falsy (default) -> bare result; truthy (keyword-only) -> `(result, sign)` pair
# NOTE: for the N-d overloads `axis` is a *required* positional parameter, and the
# return type is a scalar/array union — with an explicit axis the reduction result
# may or may not be 0-d, so neither side of the union can be ruled out statically.

# scalar real input, return_sign falsy -> float64 scalar
@overload
def logsumexp(
    a: onp.ToFloat,
    axis: AnyShape | None = None,
    b: onp.ToFloat | None = None,
    keepdims: bool = False,
    return_sign: Falsy = False,
) -> np.float64: ...
# scalar complex-or-real input, return_sign falsy -> float64 | complex128 scalar
@overload
def logsumexp(
    a: onp.ToComplex,
    axis: AnyShape | None = None,
    b: onp.ToFloat | None = None,
    keepdims: bool = False,
    return_sign: Falsy = False,
) -> np.float64 | np.complex128: ...
# N-d real input with explicit axis, return_sign falsy -> float64 scalar or array
@overload
def logsumexp(
    a: onp.ToFloatND,
    axis: AnyShape,
    b: onp.ToFloat | onp.ToFloatND | None = None,
    keepdims: bool = False,
    return_sign: Falsy = False,
) -> np.float64 | onp.ArrayND[np.float64]: ...
# N-d complex-or-real input with explicit axis, return_sign falsy -> scalar or array
@overload
def logsumexp(
    a: onp.ToComplexND,
    axis: AnyShape,
    b: onp.ToFloat | onp.ToFloatND | None = None,
    keepdims: bool = False,
    return_sign: Falsy = False,
) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
# scalar real input, return_sign truthy (keyword-only) -> (result, sign) pair
@overload
def logsumexp(
    a: onp.ToFloat,
    axis: AnyShape | None = None,
    b: onp.ToFloat | None = None,
    keepdims: bool = False,
    *,
    return_sign: Truthy,
) -> tuple[np.float64, bool | np.bool_]: ...
# scalar complex-or-real input, return_sign truthy -> (result, sign) pair
@overload
def logsumexp(
    a: onp.ToComplex,
    axis: AnyShape | None = None,
    b: onp.ToFloat | None = None,
    keepdims: bool = False,
    *,
    return_sign: Truthy,
) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ...
# N-d real input with explicit axis, return_sign truthy -> paired scalars or paired arrays
@overload
def logsumexp(
    a: onp.ToFloatND,
    axis: AnyShape,
    b: onp.ToFloat | onp.ToFloatND | None = None,
    keepdims: bool = False,
    *,
    return_sign: Truthy,
) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ...
# N-d complex-or-real input with explicit axis, return_sign truthy -> paired scalars or paired arrays
@overload
def logsumexp(
    a: onp.ToComplexND,
    axis: AnyShape,
    b: onp.ToFloat | onp.ToFloatND | None = None,
    keepdims: bool = False,
    *,
    return_sign: Truthy,
) -> (
    tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]]
): ...
41 | 79 |
|
42 |
| -# TODO: Overload real/complex and scalar/array |
43 |
| -def softmax( |
44 |
| - x: onp.ToComplex | onp.ToComplexND, |
45 |
| - axis: int | tuple[int, ...] | None = None, |
46 |
| -) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ... |
47 |
| -def log_softmax( |
48 |
| - x: onp.ToComplex | onp.ToComplexND, |
49 |
| - axis: int | tuple[int, ...] | None = None, |
50 |
| -) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ... |
| 80 | +# |
| 81 | +@overload |
| 82 | +def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ... |
| 83 | +@overload |
| 84 | +def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ... |
| 85 | +@overload |
| 86 | +def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ... |
| 87 | +@overload |
| 88 | +def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ... |
# `log_softmax`: same overload structure as `softmax` — real in, float64 out;
# complex in, possibly complex128 out; scalar and N-d inputs handled separately.
@overload
def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
@overload
def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
@overload
def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
@overload
def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...