File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed
tensorflow_probability/python/internal Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -1546,12 +1546,16 @@ def test_seed(hardcoded_seed=None,
15461546
15471547def clone_seed (seed ):
15481548 """Clone a seed: this is useful for JAX's experimental key reuse checking."""
1549- # TODO(b/328085305): switch to standard clone API when possible.
15501549 if JAX_MODE :
15511550 import jax # pylint: disable=g-import-not-at-top
1552- return jax .random .wrap_key_data (
1553- jax .random .key_data (seed ), impl = jax .random .key_impl (seed )
1554- )
1551+ if hasattr (jax .random , 'clone' ):
1552+ # jax v0.4.26 or later
1553+ return jax .random .clone (seed )
1554+ else :
1555+ # older jax versions
1556+ return jax .random .wrap_key_data (
1557+ jax .random .key_data (seed ), impl = jax .random .key_impl (seed )
1558+ )
15551559 return seed
15561560
15571561
You can’t perform that action at this time.
0 commit comments