|
29 | 29 | import hypothesis as hp
|
30 | 30 | import hypothesis.extra.numpy as hnp
|
31 | 31 | import hypothesis.strategies as hps
|
| 32 | +import mock |
32 | 33 | import numpy as np # Rewritten by script to import jax.numpy
|
33 | 34 | import numpy as onp # pylint: disable=reimported
|
| 35 | +import scipy.special as scipy_special |
34 | 36 | import six
|
35 | 37 | import tensorflow.compat.v1 as tf1
|
36 | 38 | import tensorflow.compat.v2 as tf
|
@@ -89,13 +91,21 @@ class TestCase(dict):
|
89 | 91 | def __init__(self, name, strategy_list, **kwargs):
|
90 | 92 | self.name = name
|
91 | 93 |
|
| 94 | + tensorflow_function = kwargs.pop('tensorflow_function', None) |
| 95 | + if not tensorflow_function: |
| 96 | + tensorflow_function = _getattr(tf, name) |
| 97 | + |
| 98 | + numpy_function = kwargs.pop('numpy_function', None) |
| 99 | + if not numpy_function: |
| 100 | + numpy_function = _getattr( |
| 101 | + nptf, |
| 102 | + name.replace('random.', 'random.stateless_' |
| 103 | + ).replace('random.stateless_gamma', 'random.gamma')) |
| 104 | + |
92 | 105 | super(TestCase, self).__init__(
|
93 | 106 | testcase_name='_' + name.replace('.', '_'),
|
94 |  | -        tensorflow_function=_getattr(tf, name), |
95 |  | -        numpy_function=_getattr( |
96 |  | -            nptf, |
97 |  | -            name.replace('random.', 'random.stateless_' |
98 |  | -                        ).replace('random.stateless_gamma', 'random.gamma')), |
| 107 | + tensorflow_function=tensorflow_function, |
| 108 | + numpy_function=numpy_function, |
99 | 109 | strategy_list=strategy_list,
|
100 | 110 | name=name,
|
101 | 111 | **kwargs)
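The refactor above makes the resolved callables overridable per test entry: an explicitly supplied `tensorflow_function` / `numpy_function` wins, otherwise the dotted name is looked up as before. A minimal sketch of that lookup-with-override pattern, assuming `_getattr` walks dotted attributes; the helper name `resolve_function` is illustrative, not from this commit:

```python
import functools


def resolve_function(module, name, override=None):
  # An explicitly supplied callable takes precedence; otherwise resolve the
  # dotted name attribute by attribute (what the suite's `_getattr` does).
  if override is not None:
    return override
  return functools.reduce(getattr, name.split('.'), module)


# e.g. resolve_function(tf, 'math.reduce_logsumexp') -> tf.math.reduce_logsumexp,
# while resolve_function(tf, 'math.reduce_logsumexp', override=my_fn) -> my_fn.
```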
|
@@ -677,6 +687,14 @@ def _eig_post_process(vals):
|
677 | 687 | return np.einsum('...ab,...b,...bc->...ac', v, e, v.swapaxes(-1, -2))
|
678 | 688 |
|
679 | 689 |
|
| 690 | +def _reduce_logsumexp_no_scipy(*args, **kwargs): |
| 691 | + def _not_implemented(*args, **kwargs): |
| 692 | + raise NotImplementedError() |
| 693 | + |
| 694 | + with mock.patch.object(scipy_special, 'logsumexp', _not_implemented): |
| 695 | + return nptf.reduce_logsumexp(*args, **kwargs) |
| 696 | + |
| 697 | + |
680 | 698 | # __Currently untested:__
|
681 | 699 | # broadcast_dynamic_shape
|
682 | 700 | # broadcast_static_shape
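`_reduce_logsumexp_no_scipy` above patches `scipy.special.logsumexp` to raise, which forces `nptf.reduce_logsumexp` onto its scipy-free code path. For intuition, a scipy-free log-sum-exp typically relies on the max-shift trick sketched below; this is illustrative only, not the backend's actual fallback implementation:

```python
import numpy as np


def logsumexp_fallback(x, axis=None, keepdims=False):
  # Subtract the per-reduction max before exponentiating so exp() cannot
  # overflow, then add it back outside the log.
  x = np.asarray(x)
  x_max = np.max(x, axis=axis, keepdims=True)
  x_max = np.where(np.isfinite(x_max), x_max, 0.)  # handle all-(-inf) slices
  out = np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True)) + x_max
  return out if keepdims else np.squeeze(out, axis=axis)
```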
|
@@ -812,17 +830,19 @@ def _eig_post_process(vals):
|
812 | 830 | # keywords=None,
|
813 | 831 | # defaults=(False, False, False, False, False, False, None))
|
814 | 832 | TestCase('linalg.matmul', [matmul_compatible_pairs()]),
|
815 |  | -    TestCase('linalg.eig', [pd_matrices()], post_processor=_eig_post_process, |
816 |  | -             xla_disabled=True), |
| 833 | + TestCase( |
| 834 | + 'linalg.eig', [pd_matrices()], |
| 835 | + post_processor=_eig_post_process, |
| 836 | + xla_disabled=True), |
817 | 837 | TestCase('linalg.eigh', [pd_matrices()], post_processor=_eig_post_process),
|
818 |  | -    TestCase('linalg.eigvals', [pd_matrices()], |
819 |  | -             post_processor=_eig_post_process, xla_disabled=True), |
820 |  | -    TestCase('linalg.eigvalsh', [pd_matrices()], |
821 |  | -             post_processor=_eig_post_process), |
822 | 838 | TestCase(
|
823 |  | -        'linalg.det', |
824 |  | -        [nonsingular_matrices()], |
825 |  | -        rtol=1e-3, |
| 839 | + 'linalg.eigvals', [pd_matrices()], |
| 840 | + post_processor=_eig_post_process, |
| 841 | + xla_disabled=True), |
| 842 | + TestCase( |
| 843 | + 'linalg.eigvalsh', [pd_matrices()], post_processor=_eig_post_process), |
| 844 | + TestCase( |
| 845 | + 'linalg.det', [nonsingular_matrices()], rtol=1e-3, |
826 | 846 | xla_disabled=True), # TODO(b/162937268): missing kernel.
|
827 | 847 |
|
828 | 848 | # ArgSpec(args=['a', 'name', 'conjugate'], varargs=None, keywords=None)
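Aside on the `post_processor=_eig_post_process` entries above: eigenvectors are only determined up to ordering and sign (or phase), so the test compares the reconstruction `V diag(e) V^T` rather than the raw decomposition. A small self-contained illustration of why that reconstruction is backend-independent; the helper name here is made up:

```python
import numpy as np


def comparable_eigh(m):
  # Raw eigenvectors differ between backends by sign/ordering, but the
  # reconstructed matrix V diag(e) V^T does not.
  e, v = np.linalg.eigh(m)
  return np.einsum('...ab,...b,...bc->...ac', v, e, v.swapaxes(-1, -2))


m = np.array([[2., 1.], [1., 2.]])
np.testing.assert_allclose(comparable_eigh(m), m, rtol=1e-6)  # recovers m
```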
|
@@ -963,6 +983,14 @@ def _eig_post_process(vals):
|
963 | 983 | TestCase(
|
964 | 984 | 'math.reduce_logsumexp', [array_axis_tuples(allow_multi_axis=True)],
|
965 | 985 | xla_const_args=(1,)),
|
| 986 | + TestCase( |
| 987 | + 'math.reduce_logsumexp_no_scipy', |
| 988 | + [array_axis_tuples(allow_multi_axis=True)], |
| 989 | + xla_const_args=(1,), |
| 990 | + tensorflow_function=tf.math.reduce_logsumexp, |
| 991 | + numpy_function=_reduce_logsumexp_no_scipy, |
| 992 | + disabled=JAX_MODE, # JAX always has scipy. |
| 993 | + ), |
966 | 994 | TestCase(
|
967 | 995 | 'math.reduce_max', # TODO(b/171070692): TF produces nonsense with NaN.
|
968 | 996 | [array_axis_tuples(allow_nan=False, allow_multi_axis=True)],
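The new `math.reduce_logsumexp_no_scipy` entry reuses the ordinary TF function but routes the NumPy side through the patched helper, so the scipy-free branch gets coverage (except under JAX, which always ships scipy). A rough standalone check in the same spirit, assuming `nptf` is the NumPy backend module; the import path and tolerance below are assumptions, not from the commit:

```python
import mock
import numpy as np
import scipy.special as scipy_special

from tensorflow_probability.python.internal.backend import numpy as nptf  # assumed path


def _not_implemented(*args, **kwargs):
  raise NotImplementedError()


x = np.random.randn(3, 4).astype(np.float32)
expected = nptf.reduce_logsumexp(x, axis=1)      # may dispatch to scipy
with mock.patch.object(scipy_special, 'logsumexp', _not_implemented):
  actual = nptf.reduce_logsumexp(x, axis=1)      # forced onto the fallback
np.testing.assert_allclose(expected, actual, rtol=1e-5)
```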
|
|