-from typing import overload
+from typing import Any, TypeVar, overload
 
 import numpy as np
 import optype.numpy as onp
+import optype.numpy.compat as npc
 
 from scipy._typing import AnyShape, Falsy, Truthy
 
 __all__ = ["log_softmax", "logsumexp", "softmax"]
 
-@overload
+_InexactT = TypeVar("_InexactT", bound=npc.inexact)
+_FloatingT = TypeVar("_FloatingT", bound=npc.floating)
+_CFloatingT = TypeVar("_CFloatingT", bound=npc.complexfloating)
+_InexactOrArrayT = TypeVar("_InexactOrArrayT", bound=npc.inexact | onp.ArrayND[npc.inexact])
+
+###
+
+@overload  # 0d/nd T, axis=None (default), keepdims=False (default)
 def logsumexp(
-    a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, return_sign: Falsy = False
+    a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
+    axis: None = None,
+    b: onp.ToComplex | onp.ToComplexND | None = None,
+    keepdims: Falsy = False,
+    return_sign: Falsy = False,
+) -> _InexactT: ...
+@overload  # 0d/nd +float, axis=None (default), keepdims=False (default)
+def logsumexp(
+    a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
+    axis: None = None,
+    b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
+    keepdims: Falsy = False,
+    return_sign: Falsy = False,
 ) -> np.float64: ...
-@overload
+@overload  # 0d/nd ~complex, axis=None (default), keepdims=False (default)
 def logsumexp(
-    a: onp.ToComplex,
+    a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
+    axis: None = None,
+    b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
+    keepdims: Falsy = False,
+    return_sign: Falsy = False,
+) -> np.complex128: ...
+@overload  # 0d/nd T, keepdims=True
+def logsumexp(
+    a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
     axis: AnyShape | None = None,
-    b: onp.ToFloat | None = None,
-    keepdims: bool = False,
+    b: onp.ToComplex | onp.ToComplexND | None = None,
+    *,
+    keepdims: Truthy,
+    return_sign: Falsy = False,
+) -> onp.ArrayND[_InexactT]: ...
+@overload  # 0d/nd +float, keepdims=True
+def logsumexp(
+    a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
+    axis: AnyShape | None = None,
+    b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
+    *,
+    keepdims: Truthy,
+    return_sign: Falsy = False,
+) -> onp.ArrayND[np.float64]: ...
+@overload  # 0d/nd ~complex, keepdims=True
+def logsumexp(
+    a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
+    axis: AnyShape | None = None,
+    b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
+    *,
+    keepdims: Truthy,
     return_sign: Falsy = False,
-) -> np.float64 | np.complex128: ...
-@overload
+) -> onp.ArrayND[np.complex128]: ...
+@overload  # 0d/nd T, axis=<given>
 def logsumexp(
-    a: onp.ToFloatND,
+    a: _InexactT | onp.ToArrayND[_InexactT, _InexactT],
+    axis: AnyShape,
+    b: onp.ToComplex | onp.ToComplexND | None = None,
+    *,
+    keepdims: Falsy = False,
+    return_sign: Falsy = False,
+) -> onp.ArrayND[_InexactT] | Any: ...
+@overload  # 0d/nd +float, axis=<given>
+def logsumexp(
+    a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
+    axis: AnyShape,
+    b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
+    keepdims: Falsy = False,
+    return_sign: Falsy = False,
+) -> onp.ArrayND[np.float64] | Any: ...
+@overload  # 0d/nd ~complex, axis=<given>
+def logsumexp(
+    a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
+    axis: AnyShape,
+    b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
+    keepdims: Falsy = False,
+    return_sign: Falsy = False,
+) -> onp.ArrayND[np.complex128] | Any: ...
+@overload  # floating fallback, return_sign=False
+def logsumexp(
+    a: onp.ToFloat | onp.ToFloatND,
     axis: AnyShape | None = None,
     b: onp.ToFloat | onp.ToFloatND | None = None,
     keepdims: bool = False,
     return_sign: Falsy = False,
-) -> np.float64 | onp.ArrayND[np.float64]: ...
-@overload
+) -> onp.ArrayND[np.float64 | Any] | Any: ...
+@overload  # complex fallback, return_sign=False
 def logsumexp(
-    a: onp.ToComplexND,
+    a: onp.ToComplex | onp.ToComplexND,
     axis: AnyShape | None = None,
-    b: onp.ToFloat | onp.ToFloatND | None = None,
+    b: onp.ToComplex | onp.ToComplexND | None = None,
     keepdims: bool = False,
     return_sign: Falsy = False,
-) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ...
-@overload
+) -> onp.ArrayND[np.complex128 | Any] | Any: ...
+@overload  # 0d/nd T@floating, axis=None (default), keepdims=False (default), return_sign=True
+def logsumexp(
+    a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
+    axis: None = None,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[_FloatingT, _FloatingT]: ...
+@overload  # 0d/nd +float, axis=None (default), keepdims=False (default), return_sign=True
+def logsumexp(
+    a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
+    axis: None = None,
+    b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64, np.float64]: ...
+@overload  # 0d/nd ~complex, axis=None (default), keepdims=False (default), return_sign=True
+def logsumexp(
+    a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
+    axis: None = None,
+    b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[np.float64, np.complex128]: ...
+@overload  # 0d/nd T@complexfloating, axis=None (default), keepdims=False (default), return_sign=True
+def logsumexp(
+    a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
+    axis: None = None,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[npc.floating, _CFloatingT]: ...
+@overload  # 0d/nd T@floating, keepdims=True, return_sign=True
+def logsumexp(
+    a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
+    axis: AnyShape | None = None,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    *,
+    keepdims: Truthy,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[_FloatingT], onp.ArrayND[_FloatingT]]: ...
+@overload  # 0d/nd +float, keepdims=True, return_sign=True
+def logsumexp(
+    a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
+    axis: AnyShape | None = None,
+    b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
+    *,
+    keepdims: Truthy,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.float64]]: ...
+@overload  # 0d/nd ~complex, keepdims=True, return_sign=True
+def logsumexp(
+    a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
+    axis: AnyShape | None = None,
+    b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
+    *,
+    keepdims: Truthy,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.complex128]]: ...
+@overload  # 0d/nd T@complexfloating, keepdims=True, return_sign=True
 def logsumexp(
-    a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy
-) -> tuple[np.float64, bool | np.bool_]: ...
-@overload
+    a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
+    axis: AnyShape | None = None,
+    b: onp.ToComplex | onp.ToComplexND | None = None,
+    *,
+    keepdims: Truthy,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[npc.floating], onp.ArrayND[_CFloatingT]]: ...
+@overload  # 0d/nd T@floating, axis=<given>, return_sign=True
 def logsumexp(
-    a: onp.ToComplex, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy
-) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ...
-@overload
+    a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT],
+    axis: AnyShape,
+    b: onp.ToFloat | onp.ToFloatND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[_FloatingT] | Any, onp.ArrayND[_FloatingT] | Any]: ...
+@overload  # 0d/nd +float, axis=<given>, return_sign=True
 def logsumexp(
-    a: onp.ToFloatND,
+    a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND,
+    axis: AnyShape,
+    b: onp.ToFloat64 | onp.ToFloat64_ND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.float64] | Any]: ...
+@overload  # 0d/nd ~complex, axis=<given>, return_sign=True
+def logsumexp(
+    a: onp.ToJustComplex128 | onp.ToJustComplex128_ND,
+    axis: AnyShape,
+    b: onp.ToComplex128 | onp.ToComplex128_ND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.complex128] | Any]: ...
+@overload  # 0d/nd T@complexfloating, axis=<given>, return_sign=True
+def logsumexp(
+    a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT],
+    axis: AnyShape,
+    b: onp.ToComplex | onp.ToComplexND | None = None,
+    keepdims: Falsy = False,
+    *,
+    return_sign: Truthy,
+) -> tuple[onp.ArrayND[npc.floating] | Any, onp.ArrayND[_CFloatingT] | Any]: ...
+@overload  # floating fallback, return_sign=True
+def logsumexp(
+    a: onp.ToFloat | onp.ToFloatND,
     axis: AnyShape | None = None,
     b: onp.ToFloat | onp.ToFloatND | None = None,
     keepdims: bool = False,
     *,
     return_sign: Truthy,
-) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ...
-@overload
+) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.float64 | Any] | Any]: ...
+@overload  # complex fallback, return_sign=True
 def logsumexp(
-    a: onp.ToComplexND,
+    a: onp.ToComplex | onp.ToComplexND,
     axis: AnyShape | None = None,
-    b: onp.ToFloat | onp.ToFloatND | None = None,
+    b: onp.ToComplex | onp.ToComplexND | None = None,
     keepdims: bool = False,
     *,
     return_sign: Truthy,
-) -> (
-    tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]]
-): ...
+) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.complex128 | Any] | Any]: ...
 
-#
-@overload
-def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
-@overload
-def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
-@overload
-def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
-@overload
-def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
+# NOTE: keep in sync with `log_softmax`
+@overload  # T
+def softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ...  # type: ignore[overload-overlap]
+@overload  # 0d +float64
+def softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
+@overload  # 0d ~complex128
+def softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
+@overload  # nd T@inexact
+def softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
+@overload  # nd +float64
+def softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
+@overload  # nd ~complex128
+def softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
+@overload  # 0d float fallback
+def softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
+@overload  # 0d complex fallback
+def softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
+@overload  # nd float fallback
+def softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
+@overload  # nd complex fallback
+def softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...
 
-#
-@overload
-def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64: ...
-@overload
-def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
-@overload
-def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.float64 | np.complex128: ...
-@overload
-def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | np.complex128]: ...
+# NOTE: keep in sync with `softmax`
+@overload  # T
+def log_softmax(x: _InexactOrArrayT, axis: AnyShape | None = None) -> _InexactOrArrayT: ...  # type: ignore[overload-overlap]
+@overload  # 0d +float64
+def log_softmax(x: onp.ToInt | onp.ToJustFloat64, axis: AnyShape | None = None) -> np.float64: ...
+@overload  # 0d ~complex128
+def log_softmax(x: onp.ToJustComplex128, axis: AnyShape | None = None) -> np.complex128: ...
+@overload  # nd T@inexact
+def log_softmax(x: onp.ToArrayND[_InexactT, _InexactT], axis: AnyShape | None = None) -> onp.ArrayND[_InexactT]: ...
+@overload  # nd +float64
+def log_softmax(x: onp.ToIntND | onp.ToJustFloat64_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64]: ...
+@overload  # nd ~complex128
+def log_softmax(x: onp.ToJustComplex128_ND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128]: ...
+@overload  # 0d float fallback
+def log_softmax(x: onp.ToFloat, axis: AnyShape | None = None) -> np.float64 | Any: ...
+@overload  # 0d complex fallback
+def log_softmax(x: onp.ToComplex, axis: AnyShape | None = None) -> np.complex128 | Any: ...
+@overload  # nd float fallback
+def log_softmax(x: onp.ToFloatND, axis: AnyShape | None = None) -> onp.ArrayND[np.float64 | Any]: ...
+@overload  # nd complex fallback
+def log_softmax(x: onp.ToComplexND, axis: AnyShape | None = None) -> onp.ArrayND[np.complex128 | Any]: ...
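
As a quick illustration of what the new overload ordering is meant to do for callers, here is a minimal sketch of the expected resolution. It assumes these stubs plus numpy are installed, and Python >= 3.11 for `typing.reveal_type`; the types in the comments are the intended matches (dtype-preserving overloads first, float64/complex128 promotions next, `Any`-flavored fallbacks last), not verified type-checker output.

# Sketch of the intended overload resolution under a static type checker.
from typing import reveal_type  # Python >= 3.11; else typing_extensions

import numpy as np
from scipy.special import logsumexp, softmax

x32 = np.array([1.0, 2.0, 3.0], dtype=np.float32)

reveal_type(logsumexp(x32))                    # expected: float32 (dtype-preserving overload)
reveal_type(logsumexp(x32, keepdims=True))     # expected: float32 array (keepdims=True overload)
reveal_type(logsumexp(x32, return_sign=True))  # expected: tuple[float32, float32]
reveal_type(logsumexp([1.0, 2.0]))             # expected: float64 (builtin floats promote)
reveal_type(softmax(x32))                      # expected: float32 array (nd T@inexact overload)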