@@ -285,6 +285,9 @@ def _lbeta(x, name=None): # pylint: disable=unused-argument
285
285
def _max_mask_non_finite (x , axis = - 1 , keepdims = False , mask = 0 ):
286
286
"""Returns `max` or `mask` if `max` is not finite."""
287
287
x = _convert_to_tensor (x )
288
+ if x .size == 0 :
289
+ # Return `-inf` if `x` is empty for consistency with `tf.reduce_max`.
290
+ return - np .inf
288
291
m = np .max (x , axis = _astuple (axis ), keepdims = keepdims )
289
292
needs_masking = ~ np .isfinite (m )
290
293
if needs_masking .ndim > 0 :
@@ -356,6 +359,16 @@ def _softmax(logits, axis=None, name=None): # pylint: disable=unused-argument
356
359
return y
357
360
358
361
362
+ def _alt_reduce_logsumexp (input_tensor , axis = None , keepdims = False ):
363
+ """Alternative to SP logsumexp."""
364
+ m = _max_mask_non_finite (input_tensor , axis = axis , keepdims = True )
365
+ y = input_tensor - m
366
+ y = np .exp (y , out = y )
367
+ if not keepdims :
368
+ m = np .squeeze (m , axis = _astuple (axis ))
369
+ return m + np .log (np .sum (y , axis = _astuple (axis ), keepdims = keepdims ))
370
+
371
+
359
372
def _reduce_logsumexp (input_tensor , axis = None , keepdims = False , name = None ): # pylint: disable=unused-argument
360
373
"""Computes `log(sum(exp(input_tensor))) along the specified axis."""
361
374
input_tensor = _convert_to_tensor (input_tensor )
@@ -364,18 +377,17 @@ def _reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): # py
364
377
or np .issubdtype (dtype , np .complexfloating )):
365
378
# Match TF error
366
379
raise TypeError ('Input must be either real or complex' )
380
+ if input_tensor .size == 0 :
381
+ # On empty arrays, mimic TF in returning `-inf` instead of failing, and
382
+ # preserve error message if `axis` arg is incompatible with an empty array.
383
+ return _alt_reduce_logsumexp (input_tensor , axis = axis , keepdims = keepdims )
367
384
try :
368
385
return scipy_special .logsumexp (
369
386
input_tensor , axis = _astuple (axis ), keepdims = keepdims )
370
387
except NotImplementedError :
371
388
# We offer a non SP version just in case SP isn't installed and this
372
389
# because logsumexp is often used.
373
- m = _max_mask_non_finite (input_tensor , axis = axis , keepdims = True )
374
- y = input_tensor - m
375
- y = np .exp (y , out = y )
376
- if not keepdims :
377
- m = np .squeeze (m , axis = _astuple (axis ))
378
- return m + np .log (np .sum (y , axis = _astuple (axis ), keepdims = keepdims ))
390
+ return _alt_reduce_logsumexp (input_tensor , axis = axis , keepdims = keepdims )
379
391
380
392
381
393
# Match the TF return type for top_k.
0 commit comments