Skip to content

Commit eeb795a

Browse files
brianwa84tensorflower-gardener
authored andcommitted
Short-circuit _reduce_logsumexp when input is scalar, to avoid accidentally returning a [1]-shaped array.
PiperOrigin-RevId: 817734032
1 parent 66b0926 commit eeb795a

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tensorflow_probability/python/experimental/math/manual_special_functions_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,11 @@ def test_patching(self, exp, log, expm1, log1p, logsumexp, softplus):
339339
log1p_calls += 1
340340
self.assertEqual(log1p_calls, log1p.call_count)
341341

342-
tf.math.reduce_logsumexp(0.)
342+
tf.math.reduce_logsumexp([0.])
343343
logsumexp_calls += 1
344344
self.assertEqual(logsumexp_calls, logsumexp.call_count)
345345

346-
tf.reduce_logsumexp(0.)
346+
tf.reduce_logsumexp([0.])
347347
logsumexp_calls += 1
348348
self.assertEqual(logsumexp_calls, logsumexp.call_count)
349349

tensorflow_probability/python/internal/backend/numpy/numpy_math.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ def _reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): # py
416416
or np.issubdtype(dtype, np.complexfloating)):
417417
# Match TF error
418418
raise TypeError('Input must be either real or complex')
419+
if not input_tensor.shape:
420+
return input_tensor
419421
if input_tensor.size == 0:
420422
# On empty arrays, mimic TF in returning `-inf` instead of failing, and
421423
# preserve error message if `axis` arg is incompatible with an empty array.

0 commit comments

Comments
 (0)