@@ -11,6 +11,7 @@ __all__ = ["log_softmax", "logsumexp", "softmax"]
11
11
_InexactT = TypeVar ("_InexactT" , bound = npc .inexact )
12
12
_FloatingT = TypeVar ("_FloatingT" , bound = npc .floating )
13
13
_CFloatingT = TypeVar ("_CFloatingT" , bound = npc .complexfloating )
14
+ _InexactOrArrayT = TypeVar ("_InexactOrArrayT" , bound = npc .inexact | onp .ArrayND [npc .inexact ])
14
15
15
16
###
16
17
@@ -98,7 +99,7 @@ def logsumexp(
98
99
keepdims : bool = False ,
99
100
return_sign : Falsy = False ,
100
101
) -> onp .ArrayND [np .float64 | Any ] | Any : ...
101
- @overload # ccomplex fallback, return_sign=False
102
+ @overload # complex fallback, return_sign=False
102
103
def logsumexp (
103
104
a : onp .ToComplex | onp .ToComplexND ,
104
105
axis : AnyShape | None = None ,
@@ -223,7 +224,7 @@ def logsumexp(
223
224
* ,
224
225
return_sign : Truthy ,
225
226
) -> tuple [onp .ArrayND [np .float64 | Any ] | Any , onp .ArrayND [np .float64 | Any ] | Any ]: ...
226
- @overload # ccomplex fallback, return_sign=True
227
+ @overload # complex fallback, return_sign=True
227
228
def logsumexp (
228
229
a : onp .ToComplex | onp .ToComplexND ,
229
230
axis : AnyShape | None = None ,
@@ -233,22 +234,46 @@ def logsumexp(
233
234
return_sign : Truthy ,
234
235
) -> tuple [onp .ArrayND [np .float64 | Any ] | Any , onp .ArrayND [np .complex128 | Any ] | Any ]: ...
235
236
236
- #
237
- @overload
238
- def softmax (x : onp .ToFloat , axis : AnyShape | None = None ) -> np .float64 : ...
239
- @overload
240
- def softmax (x : onp .ToFloatND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 ]: ...
241
- @overload
242
- def softmax (x : onp .ToComplex , axis : AnyShape | None = None ) -> np .float64 | np .complex128 : ...
243
- @overload
244
- def softmax (x : onp .ToComplexND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 | np .complex128 ]: ...
237
+ # NOTE: keep in sync with `log_softmax`
238
+ @overload # T
239
+ def softmax (x : _InexactOrArrayT , axis : AnyShape | None = None ) -> _InexactOrArrayT : ... # type: ignore[overload-overlap]
240
+ @overload # 0d +float64
241
+ def softmax (x : onp .ToInt | onp .ToJustFloat64 , axis : AnyShape | None = None ) -> np .float64 : ...
242
+ @overload # 0d ~complex128
243
+ def softmax (x : onp .ToJustComplex128 , axis : AnyShape | None = None ) -> np .complex128 : ...
244
+ @overload # nd T@inexact
245
+ def softmax (x : onp .ToArrayND [_InexactT , _InexactT ], axis : AnyShape | None = None ) -> onp .ArrayND [_InexactT ]: ...
246
+ @overload # nd +float64
247
+ def softmax (x : onp .ToIntND | onp .ToJustFloat64_ND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 ]: ...
248
+ @overload # nd ~complex128
249
+ def softmax (x : onp .ToJustComplex128_ND , axis : AnyShape | None = None ) -> onp .ArrayND [np .complex128 ]: ...
250
+ @overload # 0d float fallback
251
+ def softmax (x : onp .ToFloat , axis : AnyShape | None = None ) -> np .float64 | Any : ...
252
+ @overload # 0d complex fallback
253
+ def softmax (x : onp .ToComplex , axis : AnyShape | None = None ) -> np .complex128 | Any : ...
254
+ @overload # nd float fallback
255
+ def softmax (x : onp .ToFloatND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 | Any ]: ...
256
+ @overload # nd complex fallback
257
+ def softmax (x : onp .ToComplexND , axis : AnyShape | None = None ) -> onp .ArrayND [np .complex128 | Any ]: ...
245
258
246
- #
247
- @overload
248
- def log_softmax (x : onp .ToFloat , axis : AnyShape | None = None ) -> np .float64 : ...
249
- @overload
250
- def log_softmax (x : onp .ToFloatND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 ]: ...
251
- @overload
252
- def log_softmax (x : onp .ToComplex , axis : AnyShape | None = None ) -> np .float64 | np .complex128 : ...
253
- @overload
254
- def log_softmax (x : onp .ToComplexND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 | np .complex128 ]: ...
259
+ # NOTE: keep in sync with `softmax`
260
+ @overload # T
261
+ def log_softmax (x : _InexactOrArrayT , axis : AnyShape | None = None ) -> _InexactOrArrayT : ... # type: ignore[overload-overlap]
262
+ @overload # 0d +float64
263
+ def log_softmax (x : onp .ToInt | onp .ToJustFloat64 , axis : AnyShape | None = None ) -> np .float64 : ...
264
+ @overload # 0d ~complex128
265
+ def log_softmax (x : onp .ToJustComplex128 , axis : AnyShape | None = None ) -> np .complex128 : ...
266
+ @overload # nd T@inexact
267
+ def log_softmax (x : onp .ToArrayND [_InexactT , _InexactT ], axis : AnyShape | None = None ) -> onp .ArrayND [_InexactT ]: ...
268
+ @overload # nd +float64
269
+ def log_softmax (x : onp .ToIntND | onp .ToJustFloat64_ND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 ]: ...
270
+ @overload # nd ~complex128
271
+ def log_softmax (x : onp .ToJustComplex128_ND , axis : AnyShape | None = None ) -> onp .ArrayND [np .complex128 ]: ...
272
+ @overload # 0d float fallback
273
+ def log_softmax (x : onp .ToFloat , axis : AnyShape | None = None ) -> np .float64 | Any : ...
274
+ @overload # 0d complex fallback
275
+ def log_softmax (x : onp .ToComplex , axis : AnyShape | None = None ) -> np .complex128 | Any : ...
276
+ @overload # nd float fallback
277
+ def log_softmax (x : onp .ToFloatND , axis : AnyShape | None = None ) -> onp .ArrayND [np .float64 | Any ]: ...
278
+ @overload # nd complex fallback
279
+ def log_softmax (x : onp .ToComplexND , axis : AnyShape | None = None ) -> onp .ArrayND [np .complex128 | Any ]: ...
0 commit comments