Skip to content

Commit 7bcb5aa

Browse files
brianwa84tensorflower-gardener
authored andcommitted
Fix marginal_fns_test in jax.
PiperOrigin-RevId: 817662761
1 parent c28a71c commit 7bcb5aa

File tree

1 file changed

+9
-20
lines changed

1 file changed

+9
-20
lines changed

tensorflow_probability/python/experimental/distributions/marginal_fns_test.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,6 @@
2222
from tensorflow_probability.python.math import gradient
2323

2424

25-
JAX_MODE = False
26-
27-
if JAX_MODE:
28-
import jax # pylint: disable=g-import-not-at-top
29-
# TODO(b/415014385): Remove this config once the bug is fixed.
30-
jax.config.update('jax_use_direct_linearize', False)
31-
32-
3325
def _value_and_grads(f, x, has_aux=False):
3426
val, grad = gradient.value_and_gradient(f, x, has_aux=has_aux)
3527
_, grad_of_grad = gradient.value_and_gradient(
@@ -76,8 +68,8 @@ def testRetryingCholeskyWithBatchAndXLA(self):
7668
matrix, has_aux=True))
7769
self.assertAllEqual(expected, res)
7870
self.assertAllClose(expected_shift[..., 0, 0], shift)
79-
self.assertAllEqual(expected_grad, grad)
80-
self.assertAllEqual(expected_grad_of_grad, grad_of_grad)
71+
self.assertAllClose(expected_grad, grad)
72+
self.assertAllClose(expected_grad_of_grad, grad_of_grad, rtol=2e-6)
8173

8274
# Test value and gradients of XLA-compiled `retrying_cholesky`.
8375
xla_retrying_cholesky = tf.function(
@@ -119,12 +111,8 @@ def testRetryingCholeskyFloat64(self):
119111
marginal_fns.retrying_cholesky, matrix, has_aux=True))
120112
self.assertAllEqual(expected, res)
121113
self.assertAllClose(expected_shift[..., 0, 0], shift)
122-
self.assertAllEqual(expected_grad, grad)
123-
self.assertAllEqual(expected_grad_of_grad, grad_of_grad)
124-
125-
expected, expected_grad, expected_grad_of_grad = self.evaluate(
126-
_value_and_grads(lambda x: tf.linalg.cholesky(x + expected_shift),
127-
matrix))
114+
self.assertAllClose(expected_grad, grad)
115+
self.assertAllClose(expected_grad_of_grad, grad_of_grad)
128116

129117
@test_util.disable_test_for_backend(
130118
disable_numpy=True, reason='No gradients available in numpy.')
@@ -149,11 +137,12 @@ def testRetryingCholeskyFailures(self):
149137
lambda x: marginal_fns.retrying_cholesky(x, max_iters=6),
150138
matrix, has_aux=True))
151139

152-
self.assertAllEqual([expected[0], expected[2]], [res[0], res[2]])
153-
self.assertAllEqual([expected_grad[0], expected_grad[2]],
140+
self.assertAllClose([expected[0], expected[2]], [res[0], res[2]])
141+
self.assertAllClose([expected_grad[0], expected_grad[2]],
154142
[grad[0], grad[2]])
155-
self.assertAllEqual([expected_grad_of_grad[0], expected_grad_of_grad[2]],
156-
[grad_of_grad[0], grad_of_grad[2]])
143+
self.assertAllClose([expected_grad_of_grad[0], expected_grad_of_grad[2]],
144+
[grad_of_grad[0], grad_of_grad[2]],
145+
rtol=2e-6)
157146

158147
# Check that the lower-triangular part of `res[1]` is NaN.
159148
for i in range(res[1].shape[0]):

0 commit comments

Comments
 (0)