Skip to content

Commit 303e844

Browse files
danielsuotensorflower-gardener
authored andcommitted
Opt out of JAX's new direct linearizer (i.e., jax_use_direct_linearize=False) for tests that have small numerical differences.
PiperOrigin-RevId: 762192622
1 parent 56dcbbd commit 303e844

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tensorflow_probability/python/experimental/distributions/marginal_fns_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
from tensorflow_probability.python.math import gradient
2323

2424

25+
JAX_MODE = False
26+
27+
if JAX_MODE:
28+
import jax # pylint: disable=g-import-not-at-top
29+
# TODO(b/415014385): Remove this config once the bug is fixed.
30+
jax.config.update('jax_use_direct_linearize', False)
31+
32+
2533
def _value_and_grads(f, x, has_aux=False):
2634
val, grad = gradient.value_and_gradient(f, x, has_aux=has_aux)
2735
_, grad_of_grad = gradient.value_and_gradient(

0 commit comments

Comments
 (0)