Skip to content

Commit 611cd72

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Fix nptf.reduce_logsumexp when scipy is not available.
PiperOrigin-RevId: 384354674
1 parent 36d8854 commit 611cd72

File tree

2 files changed

+44
-14
lines changed

2 files changed

+44
-14
lines changed

tensorflow_probability/python/internal/backend/numpy/numpy_math.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ def _reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): # py
376376
m = _max_mask_non_finite(input_tensor, axis=axis, keepdims=True)
377377
y = input_tensor - m
378378
y = np.exp(y, out=y)
379+
if not keepdims:
380+
m = np.squeeze(m, axis=_astuple(axis))
379381
return m + np.log(np.sum(y, axis=_astuple(axis), keepdims=keepdims))
380382

381383

tensorflow_probability/python/internal/backend/numpy/numpy_test.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,10 @@
2929
import hypothesis as hp
3030
import hypothesis.extra.numpy as hnp
3131
import hypothesis.strategies as hps
32+
import mock
3233
import numpy as np # Rewritten by script to import jax.numpy
3334
import numpy as onp # pylint: disable=reimported
35+
import scipy.special as scipy_special
3436
import six
3537
import tensorflow.compat.v1 as tf1
3638
import tensorflow.compat.v2 as tf
@@ -89,13 +91,21 @@ class TestCase(dict):
8991
def __init__(self, name, strategy_list, **kwargs):
9092
self.name = name
9193

94+
tensorflow_function = kwargs.pop('tensorflow_function', None)
95+
if not tensorflow_function:
96+
tensorflow_function = _getattr(tf, name)
97+
98+
numpy_function = kwargs.pop('numpy_function', None)
99+
if not numpy_function:
100+
numpy_function = _getattr(
101+
nptf,
102+
name.replace('random.', 'random.stateless_'
103+
).replace('random.stateless_gamma', 'random.gamma'))
104+
92105
super(TestCase, self).__init__(
93106
testcase_name='_' + name.replace('.', '_'),
94-
tensorflow_function=_getattr(tf, name),
95-
numpy_function=_getattr(
96-
nptf,
97-
name.replace('random.', 'random.stateless_'
98-
).replace('random.stateless_gamma', 'random.gamma')),
107+
tensorflow_function=tensorflow_function,
108+
numpy_function=numpy_function,
99109
strategy_list=strategy_list,
100110
name=name,
101111
**kwargs)
@@ -677,6 +687,14 @@ def _eig_post_process(vals):
677687
return np.einsum('...ab,...b,...bc->...ac', v, e, v.swapaxes(-1, -2))
678688

679689

690+
def _reduce_logsumexp_no_scipy(*args, **kwargs):
691+
def _not_implemented(*args, **kwargs):
692+
raise NotImplementedError()
693+
694+
with mock.patch.object(scipy_special, 'logsumexp', _not_implemented):
695+
return nptf.reduce_logsumexp(*args, **kwargs)
696+
697+
680698
# __Currently untested:__
681699
# broadcast_dynamic_shape
682700
# broadcast_static_shape
@@ -812,17 +830,19 @@ def _eig_post_process(vals):
812830
# keywords=None,
813831
# defaults=(False, False, False, False, False, False, None))
814832
TestCase('linalg.matmul', [matmul_compatible_pairs()]),
815-
TestCase('linalg.eig', [pd_matrices()], post_processor=_eig_post_process,
816-
xla_disabled=True),
833+
TestCase(
834+
'linalg.eig', [pd_matrices()],
835+
post_processor=_eig_post_process,
836+
xla_disabled=True),
817837
TestCase('linalg.eigh', [pd_matrices()], post_processor=_eig_post_process),
818-
TestCase('linalg.eigvals', [pd_matrices()],
819-
post_processor=_eig_post_process, xla_disabled=True),
820-
TestCase('linalg.eigvalsh', [pd_matrices()],
821-
post_processor=_eig_post_process),
822838
TestCase(
823-
'linalg.det',
824-
[nonsingular_matrices()],
825-
rtol=1e-3,
839+
'linalg.eigvals', [pd_matrices()],
840+
post_processor=_eig_post_process,
841+
xla_disabled=True),
842+
TestCase(
843+
'linalg.eigvalsh', [pd_matrices()], post_processor=_eig_post_process),
844+
TestCase(
845+
'linalg.det', [nonsingular_matrices()], rtol=1e-3,
826846
xla_disabled=True), # TODO(b/162937268): missing kernel.
827847

828848
# ArgSpec(args=['a', 'name', 'conjugate'], varargs=None, keywords=None)
@@ -963,6 +983,14 @@ def _eig_post_process(vals):
963983
TestCase(
964984
'math.reduce_logsumexp', [array_axis_tuples(allow_multi_axis=True)],
965985
xla_const_args=(1,)),
986+
TestCase(
987+
'math.reduce_logsumexp_no_scipy',
988+
[array_axis_tuples(allow_multi_axis=True)],
989+
xla_const_args=(1,),
990+
tensorflow_function=tf.math.reduce_logsumexp,
991+
numpy_function=_reduce_logsumexp_no_scipy,
992+
disabled=JAX_MODE, # JAX always has scipy.
993+
),
966994
TestCase(
967995
'math.reduce_max', # TODO(b/171070692): TF produces nonsense with NaN.
968996
[array_axis_tuples(allow_nan=False, allow_multi_axis=True)],

0 commit comments

Comments
 (0)