Skip to content

Commit 73a4a75

Browse files
superbobrytensorflower-gardener
authored andcommitted
Updated jax.config import
PiperOrigin-RevId: 574931286
1 parent e16c0b7 commit 73a4a75

File tree

9 files changed

+9
-9
lines changed

9 files changed

+9
-9
lines changed

spinoffs/fun_mc/fun_mc/fun_mc_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from absl.testing import parameterized
2424
import jax
25-
from jax.config import config as jax_config
25+
from jax import config as jax_config
2626
import numpy as np
2727
import scipy.stats
2828
import tensorflow.compat.v2 as real_tf

spinoffs/fun_mc/fun_mc/malt_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# Dependency imports
2121

2222
import jax
23-
from jax.config import config as jax_config
23+
from jax import config as jax_config
2424
import numpy as np
2525
import tensorflow.compat.v2 as real_tf
2626

spinoffs/fun_mc/fun_mc/prefab_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# Dependency imports
2121

2222
import jax
23-
from jax.config import config as jax_config
23+
from jax import config as jax_config
2424
import numpy as np
2525
import tensorflow.compat.v2 as real_tf
2626

spinoffs/fun_mc/fun_mc/sga_hmc_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from absl.testing import parameterized
2323
import jax
24-
from jax.config import config as jax_config
24+
from jax import config as jax_config
2525
import tensorflow.compat.v2 as real_tf
2626

2727
from tensorflow_probability.python.internal import test_util as tfp_test_util

spinoffs/fun_mc/fun_mc/util_tfp_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# Dependency imports
1818

1919
from absl.testing import parameterized
20-
from jax.config import config as jax_config
20+
from jax import config as jax_config
2121
import numpy as np
2222
import tensorflow.compat.v2 as real_tf
2323

tensorflow_probability/examples/jupyter_notebooks/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand_JAX.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
"\n",
9696
"import numpy as np\n",
9797
"import jax\n",
98-
"from jax.config import config\n",
98+
"from jax import config\n",
9999
"config.update('jax_enable_x64', True)\n",
100100
"\n",
101101
"from tensorflow_probability.substrates import jax as tfp\n",

tensorflow_probability/examples/jupyter_notebooks/TFP_Release_Notebook_0_11_0.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@
143143
},
144144
"source": [
145145
"import jax\n",
146-
"from jax.config import config\n",
146+
"from jax import config\n",
147147
"config.update('jax_enable_x64', True)\n",
148148
"\n",
149149
"def demo_jax():\n",

tensorflow_probability/python/internal/samplers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def setUp(self):
3737
super().setUp()
3838

3939
if JAX_MODE and FLAGS.test_tfp_jax_prng != 'default':
40-
from jax.config import config # pylint: disable=g-import-not-at-top
40+
from jax import config # pylint: disable=g-import-not-at-top
4141
config.update('jax_default_prng_impl', FLAGS.test_tfp_jax_prng)
4242

4343
@test_util.substrate_disable_stateful_random_test

tensorflow_probability/python/internal/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2021,7 +2021,7 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
20212021
def main(jax_mode=JAX_MODE, jax_enable_x64=True):
20222022
"""Test main function that injects a custom loader."""
20232023
if jax_mode and jax_enable_x64:
2024-
from jax.config import config # pylint: disable=g-import-not-at-top
2024+
from jax import config # pylint: disable=g-import-not-at-top
20252025
config.update('jax_enable_x64', True)
20262026

20272027
# This logic is borrowed from TensorFlow.

0 commit comments

Comments
 (0)