Skip to content

Commit 714d547

Browse files
urskjburnim
authored andcommitted
Allow unit tests to disable 64 bit precision mode in jax.
PiperOrigin-RevId: 550617289
1 parent 3f42739 commit 714d547

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tensorflow_probability/python/internal/test_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,9 +2010,9 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
20102010
return names
20112011

20122012

2013-
def main(jax_mode=JAX_MODE):
2013+
def main(jax_mode=JAX_MODE, jax_enable_x64=True):
20142014
"""Test main function that injects a custom loader."""
2015-
if jax_mode:
2015+
if jax_mode and jax_enable_x64:
20162016
from jax.config import config # pylint: disable=g-import-not-at-top
20172017
config.update('jax_enable_x64', True)
20182018

0 commit comments

Comments
 (0)