2222from tensorflow_probability .python .math import gradient
2323
2424
25- JAX_MODE = False
26-
27- if JAX_MODE :
28- import jax # pylint: disable=g-import-not-at-top
29- # TODO(b/415014385): Remove this config once the bug is fixed.
30- jax .config .update ('jax_use_direct_linearize' , False )
31-
32-
3325def _value_and_grads (f , x , has_aux = False ):
3426 val , grad = gradient .value_and_gradient (f , x , has_aux = has_aux )
3527 _ , grad_of_grad = gradient .value_and_gradient (
@@ -76,8 +68,8 @@ def testRetryingCholeskyWithBatchAndXLA(self):
7668 matrix , has_aux = True ))
7769 self .assertAllEqual (expected , res )
7870 self .assertAllClose (expected_shift [..., 0 , 0 ], shift )
79- self .assertAllEqual (expected_grad , grad )
80- self .assertAllEqual (expected_grad_of_grad , grad_of_grad )
71+ self .assertAllClose (expected_grad , grad )
72+ self .assertAllClose (expected_grad_of_grad , grad_of_grad , rtol = 2e-6 )
8173
8274 # Test value and gradients of XLA-compiled `retrying_cholesky`.
8375 xla_retrying_cholesky = tf .function (
@@ -119,12 +111,8 @@ def testRetryingCholeskyFloat64(self):
119111 marginal_fns .retrying_cholesky , matrix , has_aux = True ))
120112 self .assertAllEqual (expected , res )
121113 self .assertAllClose (expected_shift [..., 0 , 0 ], shift )
122- self .assertAllEqual (expected_grad , grad )
123- self .assertAllEqual (expected_grad_of_grad , grad_of_grad )
124-
125- expected , expected_grad , expected_grad_of_grad = self .evaluate (
126- _value_and_grads (lambda x : tf .linalg .cholesky (x + expected_shift ),
127- matrix ))
114+ self .assertAllClose (expected_grad , grad )
115+ self .assertAllClose (expected_grad_of_grad , grad_of_grad )
128116
129117 @test_util .disable_test_for_backend (
130118 disable_numpy = True , reason = 'No gradients available in numpy.' )
@@ -149,11 +137,12 @@ def testRetryingCholeskyFailures(self):
149137 lambda x : marginal_fns .retrying_cholesky (x , max_iters = 6 ),
150138 matrix , has_aux = True ))
151139
152- self .assertAllEqual ([expected [0 ], expected [2 ]], [res [0 ], res [2 ]])
153- self .assertAllEqual ([expected_grad [0 ], expected_grad [2 ]],
140+ self .assertAllClose ([expected [0 ], expected [2 ]], [res [0 ], res [2 ]])
141+ self .assertAllClose ([expected_grad [0 ], expected_grad [2 ]],
154142 [grad [0 ], grad [2 ]])
155- self .assertAllEqual ([expected_grad_of_grad [0 ], expected_grad_of_grad [2 ]],
156- [grad_of_grad [0 ], grad_of_grad [2 ]])
143+ self .assertAllClose ([expected_grad_of_grad [0 ], expected_grad_of_grad [2 ]],
144+ [grad_of_grad [0 ], grad_of_grad [2 ]],
145+ rtol = 2e-6 )
157146
158147 # Check that the lower-triangular part of `res[1]` is NaN.
159148 for i in range (res [1 ].shape [0 ]):
0 commit comments