Skip to content

Commit bfa5748

Browse files
emilyfertig authored and tensorflower-gardener committed
Fix handling of empty arrays by reduce_logsumexp in Numpy backend. Do not recurse through Numpy functions in copy_docstring.
PiperOrigin-RevId: 456430771
1 parent 3f5749a commit bfa5748

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

tensorflow_probability/python/internal/backend/numpy/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _find_method_from_name(scope, name):
4646
child = scope[method[0]]
4747
else:
4848
child = getattr(scope, method[0])
49-
if len(method) == 1:
49+
if len(method) == 1 or (len(method) == 2 and method[0] == 'np'):
5050
return child
5151
return _find_method_from_name(child, method[1])
5252

tensorflow_probability/python/internal/backend/numpy/numpy_math.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ def _lbeta(x, name=None): # pylint: disable=unused-argument
285285
def _max_mask_non_finite(x, axis=-1, keepdims=False, mask=0):
286286
"""Returns `max` or `mask` if `max` is not finite."""
287287
x = _convert_to_tensor(x)
288+
if x.size == 0:
289+
# Return `-inf` if `x` is empty for consistency with `tf.reduce_max`.
290+
return -np.inf
288291
m = np.max(x, axis=_astuple(axis), keepdims=keepdims)
289292
needs_masking = ~np.isfinite(m)
290293
if needs_masking.ndim > 0:
@@ -356,6 +359,16 @@ def _softmax(logits, axis=None, name=None): # pylint: disable=unused-argument
356359
return y
357360

358361

362+
def _alt_reduce_logsumexp(input_tensor, axis=None, keepdims=False):
363+
"""Alternative to SP logsumexp."""
364+
m = _max_mask_non_finite(input_tensor, axis=axis, keepdims=True)
365+
y = input_tensor - m
366+
y = np.exp(y, out=y)
367+
if not keepdims:
368+
m = np.squeeze(m, axis=_astuple(axis))
369+
return m + np.log(np.sum(y, axis=_astuple(axis), keepdims=keepdims))
370+
371+
359372
def _reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): # pylint: disable=unused-argument
360373
"""Computes `log(sum(exp(input_tensor))) along the specified axis."""
361374
input_tensor = _convert_to_tensor(input_tensor)
@@ -364,18 +377,17 @@ def _reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): # py
364377
or np.issubdtype(dtype, np.complexfloating)):
365378
# Match TF error
366379
raise TypeError('Input must be either real or complex')
380+
if input_tensor.size == 0:
381+
# On empty arrays, mimic TF in returning `-inf` instead of failing, and
382+
# preserve error message if `axis` arg is incompatible with an empty array.
383+
return _alt_reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
367384
try:
368385
return scipy_special.logsumexp(
369386
input_tensor, axis=_astuple(axis), keepdims=keepdims)
370387
except NotImplementedError:
371388
# We offer a non SP version just in case SP isn't installed and this
372389
# because logsumexp is often used.
373-
m = _max_mask_non_finite(input_tensor, axis=axis, keepdims=True)
374-
y = input_tensor - m
375-
y = np.exp(y, out=y)
376-
if not keepdims:
377-
m = np.squeeze(m, axis=_astuple(axis))
378-
return m + np.log(np.sum(y, axis=_astuple(axis), keepdims=keepdims))
390+
return _alt_reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
379391

380392

381393
# Match the TF return type for top_k.

0 commit comments

Comments (0)