|
1 |
| -from typing import overload |
| 1 | +from typing import Any, TypeVar, overload |
2 | 2 |
|
3 | 3 | import numpy as np
|
4 | 4 | import optype.numpy as onp
|
| 5 | +import optype.numpy.compat as npc |
5 | 6 |
|
6 | 7 | from scipy._typing import AnyShape, Falsy, Truthy
|
7 | 8 |
|
8 | 9 | __all__ = ["log_softmax", "logsumexp", "softmax"]
|
9 | 10 |
|
10 |
| -@overload |
| 11 | +_InexactT = TypeVar("_InexactT", bound=npc.inexact) |
| 12 | +_FloatingT = TypeVar("_FloatingT", bound=npc.floating) |
| 13 | +_CFloatingT = TypeVar("_CFloatingT", bound=npc.complexfloating) |
| 14 | + |
| 15 | +### |
| 16 | + |
| 17 | +@overload # 0d/nd T, axis=None (default), keepdims=False (default) |
| 18 | +def logsumexp( |
| 19 | + a: _InexactT | onp.ToArrayND[_InexactT, _InexactT], |
| 20 | + axis: None = None, |
| 21 | + b: onp.ToComplex | onp.ToComplexND | None = None, |
| 22 | + keepdims: Falsy = False, |
| 23 | + return_sign: Falsy = False, |
| 24 | +) -> _InexactT: ... |
| 25 | +@overload # 0d/nd +float, axis=None (default), keepdims=False (default) |
11 | 26 | def logsumexp(
|
12 |
| - a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, return_sign: Falsy = False |
| 27 | + a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND, |
| 28 | + axis: None = None, |
| 29 | + b: onp.ToFloat64 | onp.ToFloat64_ND | None = None, |
| 30 | + keepdims: Falsy = False, |
| 31 | + return_sign: Falsy = False, |
13 | 32 | ) -> np.float64: ...
|
14 |
| -@overload |
| 33 | +@overload # 0d/nd ~complex, axis=None (default), keepdims=False (default) |
15 | 34 | def logsumexp(
|
16 |
| - a: onp.ToComplex, |
| 35 | + a: onp.ToJustComplex128 | onp.ToJustComplex128_ND, |
| 36 | + axis: None = None, |
| 37 | + b: onp.ToComplex128 | onp.ToComplex128_ND | None = None, |
| 38 | + keepdims: Falsy = False, |
| 39 | + return_sign: Falsy = False, |
| 40 | +) -> np.complex128: ... |
| 41 | +@overload # 0d/nd T, keepdims=True |
| 42 | +def logsumexp( |
| 43 | + a: _InexactT | onp.ToArrayND[_InexactT, _InexactT], |
17 | 44 | axis: AnyShape | None = None,
|
18 |
| - b: onp.ToFloat | None = None, |
19 |
| - keepdims: bool = False, |
| 45 | + b: onp.ToComplex | onp.ToComplexND | None = None, |
| 46 | + *, |
| 47 | + keepdims: Truthy, |
20 | 48 | return_sign: Falsy = False,
|
21 |
| -) -> np.float64 | np.complex128: ... |
22 |
| -@overload |
| 49 | +) -> onp.ArrayND[_InexactT]: ... |
| 50 | +@overload # 0d/nd +float, keepdims=True |
| 51 | +def logsumexp( |
| 52 | + a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND, |
| 53 | + axis: AnyShape | None = None, |
| 54 | + b: onp.ToFloat64 | onp.ToFloat64_ND | None = None, |
| 55 | + *, |
| 56 | + keepdims: Truthy, |
| 57 | + return_sign: Falsy = False, |
| 58 | +) -> onp.ArrayND[np.float64]: ... |
| 59 | +@overload # 0d/nd ~complex, keepdims=True |
| 60 | +def logsumexp( |
| 61 | + a: onp.ToJustComplex128 | onp.ToJustComplex128_ND, |
| 62 | + axis: AnyShape | None = None, |
| 63 | + b: onp.ToComplex128 | onp.ToComplex128_ND | None = None, |
| 64 | + *, |
| 65 | + keepdims: Truthy, |
| 66 | + return_sign: Falsy = False, |
| 67 | +) -> onp.ArrayND[np.complex128]: ... |
| 68 | +@overload # 0d/nd T, axis=<given> |
| 69 | +def logsumexp( |
| 70 | + a: _InexactT | onp.ToArrayND[_InexactT, _InexactT], |
| 71 | + axis: AnyShape, |
| 72 | + b: onp.ToComplex | onp.ToComplexND | None = None, |
| 73 | + *, |
| 74 | + keepdims: Falsy = False, |
| 75 | + return_sign: Falsy = False, |
| 76 | +) -> onp.ArrayND[_InexactT] | Any: ... |
| 77 | +@overload # 0d/nd +float, axis=<given> |
| 78 | +def logsumexp( |
| 79 | + a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND, |
| 80 | + axis: AnyShape, |
| 81 | + b: onp.ToFloat64 | onp.ToFloat64_ND | None = None, |
| 82 | + keepdims: Falsy = False, |
| 83 | + return_sign: Falsy = False, |
| 84 | +) -> onp.ArrayND[np.float64] | Any: ... |
| 85 | +@overload # 0d/nd ~complex, axis=<given> |
| 86 | +def logsumexp( |
| 87 | + a: onp.ToJustComplex128 | onp.ToJustComplex128_ND, |
| 88 | + axis: AnyShape, |
| 89 | + b: onp.ToComplex128 | onp.ToComplex128_ND | None = None, |
| 90 | + keepdims: Falsy = False, |
| 91 | + return_sign: Falsy = False, |
| 92 | +) -> onp.ArrayND[np.complex128] | Any: ... |
| 93 | +@overload # floating fallback, return_sign=False |
23 | 94 | def logsumexp(
|
24 |
| - a: onp.ToFloatND, |
| 95 | + a: onp.ToFloat | onp.ToFloatND, |
25 | 96 | axis: AnyShape | None = None,
|
26 | 97 | b: onp.ToFloat | onp.ToFloatND | None = None,
|
27 | 98 | keepdims: bool = False,
|
28 | 99 | return_sign: Falsy = False,
|
29 |
| -) -> np.float64 | onp.ArrayND[np.float64]: ... |
30 |
| -@overload |
| 100 | +) -> onp.ArrayND[np.float64 | Any] | Any: ... |
| 101 | +@overload # complex fallback, return_sign=False |
31 | 102 | def logsumexp(
|
32 |
| - a: onp.ToComplexND, |
| 103 | + a: onp.ToComplex | onp.ToComplexND, |
33 | 104 | axis: AnyShape | None = None,
|
34 |
| - b: onp.ToFloat | onp.ToFloatND | None = None, |
| 105 | + b: onp.ToComplex | onp.ToComplexND | None = None, |
35 | 106 | keepdims: bool = False,
|
36 | 107 | return_sign: Falsy = False,
|
37 |
| -) -> np.float64 | np.complex128 | onp.ArrayND[np.float64 | np.complex128]: ... |
38 |
| -@overload |
| 108 | +) -> onp.ArrayND[np.complex128 | Any] | Any: ... |
| 109 | +@overload # 0d/nd T@floating, axis=None (default), keepdims=False (default), return_sign=True |
39 | 110 | def logsumexp(
|
40 |
| - a: onp.ToFloat, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy |
41 |
| -) -> tuple[np.float64, bool | np.bool_]: ... |
42 |
| -@overload |
| 111 | + a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT], |
| 112 | + axis: None = None, |
| 113 | + b: onp.ToFloat | onp.ToFloatND | None = None, |
| 114 | + keepdims: Falsy = False, |
| 115 | + *, |
| 116 | + return_sign: Truthy, |
| 117 | +) -> tuple[_FloatingT, _FloatingT]: ... |
| 118 | +@overload # 0d/nd +float, axis=None (default), keepdims=False (default), return_sign=True |
43 | 119 | def logsumexp(
|
44 |
| - a: onp.ToComplex, axis: AnyShape | None = None, b: onp.ToFloat | None = None, keepdims: bool = False, *, return_sign: Truthy |
45 |
| -) -> tuple[np.float64 | np.complex128, bool | np.bool_]: ... |
46 |
| -@overload |
| 120 | + a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND, |
| 121 | + axis: None = None, |
| 122 | + b: onp.ToFloat64 | onp.ToFloat64_ND | None = None, |
| 123 | + keepdims: Falsy = False, |
| 124 | + *, |
| 125 | + return_sign: Truthy, |
| 126 | +) -> tuple[np.float64, np.float64]: ... |
| 127 | +@overload # 0d/nd ~complex, axis=None (default), keepdims=False (default), return_sign=True |
| 128 | +def logsumexp( |
| 129 | + a: onp.ToJustComplex128 | onp.ToJustComplex128_ND, |
| 130 | + axis: None = None, |
| 131 | + b: onp.ToComplex128 | onp.ToComplex128_ND | None = None, |
| 132 | + keepdims: Falsy = False, |
| 133 | + *, |
| 134 | + return_sign: Truthy, |
| 135 | +) -> tuple[np.float64, np.complex128]: ... |
| 136 | +@overload # 0d/nd T@complexfloating, axis=None (default), keepdims=False (default), return_sign=True |
| 137 | +def logsumexp( |
| 138 | + a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT], |
| 139 | + axis: None = None, |
| 140 | + b: onp.ToFloat | onp.ToFloatND | None = None, |
| 141 | + keepdims: Falsy = False, |
| 142 | + *, |
| 143 | + return_sign: Truthy, |
| 144 | +) -> tuple[npc.floating, _CFloatingT]: ... |
| 145 | +@overload # 0d/nd T@floating, keepdims=True, return_sign=True |
47 | 146 | def logsumexp(
|
48 |
| - a: onp.ToFloatND, |
| 147 | + a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT], |
49 | 148 | axis: AnyShape | None = None,
|
50 | 149 | b: onp.ToFloat | onp.ToFloatND | None = None,
|
51 |
| - keepdims: bool = False, |
52 | 150 | *,
|
| 151 | + keepdims: Truthy, |
53 | 152 | return_sign: Truthy,
|
54 |
| -) -> tuple[np.float64, bool | np.bool_] | tuple[onp.ArrayND[np.float64], onp.ArrayND[np.bool_]]: ... |
55 |
| -@overload |
| 153 | +) -> tuple[onp.ArrayND[_FloatingT], onp.ArrayND[_FloatingT]]: ... |
| 154 | +@overload # 0d/nd +float, keepdims=True, return_sign=True |
| 155 | +def logsumexp( |
| 156 | + a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND, |
| 157 | + axis: AnyShape | None = None, |
| 158 | + b: onp.ToFloat64 | onp.ToFloat64_ND | None = None, |
| 159 | + *, |
| 160 | + keepdims: Truthy, |
| 161 | + return_sign: Truthy, |
| 162 | +) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.float64]]: ... |
| 163 | +@overload # 0d/nd ~complex, keepdims=True, return_sign=True |
| 164 | +def logsumexp( |
| 165 | + a: onp.ToJustComplex128 | onp.ToJustComplex128_ND, |
| 166 | + axis: AnyShape | None = None, |
| 167 | + b: onp.ToComplex128 | onp.ToComplex128_ND | None = None, |
| 168 | + *, |
| 169 | + keepdims: Truthy, |
| 170 | + return_sign: Truthy, |
| 171 | +) -> tuple[onp.ArrayND[np.float64], onp.ArrayND[np.complex128]]: ... |
| 172 | +@overload # 0d/nd T@complexfloating, keepdims=True, return_sign=True |
| 173 | +def logsumexp( |
| 174 | + a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT], |
| 175 | + axis: AnyShape | None = None, |
| 176 | + b: onp.ToComplex | onp.ToComplexND | None = None, |
| 177 | + *, |
| 178 | + keepdims: Truthy, |
| 179 | + return_sign: Truthy, |
| 180 | +) -> tuple[onp.ArrayND[npc.floating], onp.ArrayND[_CFloatingT]]: ... |
| 181 | +@overload # 0d/nd T@floating, axis=<given>, return_sign=True |
| 182 | +def logsumexp( |
| 183 | + a: _FloatingT | onp.ToArrayND[_FloatingT, _FloatingT], |
| 184 | + axis: AnyShape, |
| 185 | + b: onp.ToFloat | onp.ToFloatND | None = None, |
| 186 | + keepdims: Falsy = False, |
| 187 | + *, |
| 188 | + return_sign: Truthy, |
| 189 | +) -> tuple[onp.ArrayND[_FloatingT] | Any, onp.ArrayND[_FloatingT] | Any]: ... |
| 190 | +@overload # 0d/nd +float, axis=<given>, return_sign=True |
56 | 191 | def logsumexp(
|
57 |
| - a: onp.ToComplexND, |
| 192 | + a: onp.ToInt | onp.ToIntND | onp.ToJustFloat64 | onp.ToJustFloat64_ND, |
| 193 | + axis: AnyShape, |
| 194 | + b: onp.ToFloat64 | onp.ToFloat64_ND | None = None, |
| 195 | + keepdims: Falsy = False, |
| 196 | + *, |
| 197 | + return_sign: Truthy, |
| 198 | +) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.float64] | Any]: ... |
| 199 | +@overload # 0d/nd ~complex, axis=<given>, return_sign=True |
| 200 | +def logsumexp( |
| 201 | + a: onp.ToJustComplex128 | onp.ToJustComplex128_ND, |
| 202 | + axis: AnyShape, |
| 203 | + b: onp.ToComplex128 | onp.ToComplex128_ND | None = None, |
| 204 | + keepdims: Falsy = False, |
| 205 | + *, |
| 206 | + return_sign: Truthy, |
| 207 | +) -> tuple[onp.ArrayND[np.float64] | Any, onp.ArrayND[np.complex128] | Any]: ... |
| 208 | +@overload # 0d/nd T@complexfloating, axis=<given>, return_sign=True |
| 209 | +def logsumexp( |
| 210 | + a: _CFloatingT | onp.ToArrayND[_CFloatingT, _CFloatingT], |
| 211 | + axis: AnyShape, |
| 212 | + b: onp.ToComplex | onp.ToComplexND | None = None, |
| 213 | + keepdims: Falsy = False, |
| 214 | + *, |
| 215 | + return_sign: Truthy, |
| 216 | +) -> tuple[onp.ArrayND[npc.floating] | Any, onp.ArrayND[_CFloatingT] | Any]: ... |
| 217 | +@overload # floating fallback, return_sign=True |
| 218 | +def logsumexp( |
| 219 | + a: onp.ToFloat | onp.ToFloatND, |
58 | 220 | axis: AnyShape | None = None,
|
59 | 221 | b: onp.ToFloat | onp.ToFloatND | None = None,
|
60 | 222 | keepdims: bool = False,
|
61 | 223 | *,
|
62 | 224 | return_sign: Truthy,
|
63 |
| -) -> ( |
64 |
| - tuple[np.float64 | np.complex128, bool | np.bool_] | tuple[onp.ArrayND[np.float64 | np.complex128], onp.ArrayND[np.bool_]] |
65 |
| -): ... |
| 225 | +) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.float64 | Any] | Any]: ... |
| 226 | +@overload # complex fallback, return_sign=True |
| 227 | +def logsumexp( |
| 228 | + a: onp.ToComplex | onp.ToComplexND, |
| 229 | + axis: AnyShape | None = None, |
| 230 | + b: onp.ToComplex | onp.ToComplexND | None = None, |
| 231 | + keepdims: bool = False, |
| 232 | + *, |
| 233 | + return_sign: Truthy, |
| 234 | +) -> tuple[onp.ArrayND[np.float64 | Any] | Any, onp.ArrayND[np.complex128 | Any] | Any]: ... |
66 | 235 |
|
67 | 236 | #
|
68 | 237 | @overload
|
|
0 commit comments