Skip to content

Commit 0435c36

Browse files
vanderplastensorflower-gardener
authored andcommitted
test_util: use jax.random.clone when available
PiperOrigin-RevId: 614035309
1 parent 0af6f41 commit 0435c36

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tensorflow_probability/python/internal/test_util.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,12 +1546,16 @@ def test_seed(hardcoded_seed=None,
15461546

15471547
def clone_seed(seed):
15481548
"""Clone a seed: this is useful for JAX's experimental key reuse checking."""
1549-
# TODO(b/328085305): switch to standard clone API when possible.
15501549
if JAX_MODE:
15511550
import jax # pylint: disable=g-import-not-at-top
1552-
return jax.random.wrap_key_data(
1553-
jax.random.key_data(seed), impl=jax.random.key_impl(seed)
1554-
)
1551+
if hasattr(jax.random, 'clone'):
1552+
# jax v0.4.26 or later
1553+
return jax.random.clone(seed)
1554+
else:
1555+
# older jax versions
1556+
return jax.random.wrap_key_data(
1557+
jax.random.key_data(seed), impl=jax.random.key_impl(seed)
1558+
)
15551559
return seed
15561560

15571561

0 commit comments

Comments
 (0)