Skip to content

Commit 2c56628

Browse files
sharadmvtensorflower-gardener
authored andcommitted
BREAKING CHANGE: Stop unpacking seeds when splitting in JAX
Before this change `tfp.random.split_seed` would return a list of seeds instead. of an array of seeds, causing a gather for each seed. With large number of seeds, this can cause a slowdown in both trace and compile time. This change returns an array of seeds. If your code relies on using a list of seeds instead of an array, you can wrap the call to `tfp.random.split_seed` in a `list` or `jnp.unstack`. PiperOrigin-RevId: 386558360
1 parent fa952b3 commit 2c56628

File tree

6 files changed

+10
-9
lines changed

6 files changed

+10
-9
lines changed

tensorflow_probability/python/internal/samplers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,13 @@ def split_seed(seed, n=2, salt=None, name=None):
160160
161161
See https://github.com/tensorflow/probability/blob/main/PRNGS.md
162162
for details.
163-
164163
Args:
165164
seed: The seed to split; may be an `int`, an `(int, int) tuple`, or a
166165
`Tensor`. `int` seeds are converted to `Tensor` seeds using
167166
`tf.random.uniform` stateful sampling. Tuples are converted to `Tensor`.
168-
n: The number of splits to return.
167+
n: The number of splits to return. In TensorFlow, if `n` is an integer, this
168+
function returns a list of seeds and otherwise returns a `Tensor` of
169+
seeds. In JAX, this function always returns an array of seeds.
169170
salt: Optional `str` salt to mix with the seed.
170171
name: Optional name to scope related ops.
171172
@@ -184,7 +185,7 @@ def split_seed(seed, n=2, salt=None, name=None):
184185
seed = sanitize_seed(seed, salt=salt)
185186
if JAX_MODE:
186187
from jax import random as jaxrand # pylint: disable=g-import-not-at-top
187-
return list(jaxrand.split(seed, n))
188+
return jaxrand.split(seed, n)
188189
seeds = tf.random.stateless_uniform(
189190
[n, 2], seed=seed, minval=None, maxval=None, dtype=SEED_DTYPE)
190191
if isinstance(n, six.integer_types):

tensorflow_probability/python/mcmc/hmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
720720
state_gradients_are_stopped=self.state_gradients_are_stopped)
721721

722722
seed = samplers.sanitize_seed(seed) # Retain for diagnostics.
723-
seeds = samplers.split_seed(seed, n=len(current_state_parts))
723+
seeds = list(samplers.split_seed(seed, n=len(current_state_parts)))
724724
seeds = distribute_lib.fold_in_axis_index(
725725
seeds, self.experimental_shard_axis_names)
726726

tensorflow_probability/python/mcmc/langevin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
455455
self.parallel_iterations)
456456

457457
seed = samplers.sanitize_seed(seed) # Retain for diagnostics.
458-
seeds = samplers.split_seed(
459-
seed, n=len(current_state_parts), salt='langevin.one_step')
458+
seeds = list(samplers.split_seed(
459+
seed, n=len(current_state_parts), salt='langevin.one_step'))
460460
seeds = distribute_lib.fold_in_axis_index(
461461
seeds, self.experimental_shard_axis_names)
462462

tensorflow_probability/python/mcmc/nuts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def experimental_with_shard_axes(self, shard_axis_names):
519519
def _start_trajectory_batched(self, state, target_log_prob, seed):
520520
"""Computations needed to start a trajectory."""
521521
with tf.name_scope('start_trajectory_batched'):
522-
seeds = samplers.split_seed(seed, n=len(state) + 1)
522+
seeds = list(samplers.split_seed(seed, n=len(state) + 1))
523523
momentum_seeds = distribute_lib.fold_in_axis_index(
524524
seeds[:-1], self.experimental_shard_axis_names)
525525
momentum = [

tensorflow_probability/python/mcmc/random_walk_metropolis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _fn(state_parts, seed, experimental_shard_axis_names=None):
104104
if len(state_parts) != len(scales):
105105
raise ValueError('`scale` must broadcast with `state_parts`.')
106106

107-
part_seeds = samplers.split_seed(seed, n=len(state_parts))
107+
part_seeds = list(samplers.split_seed(seed, n=len(state_parts)))
108108
part_seeds = distribute_lib.fold_in_axis_index(
109109
part_seeds, experimental_shard_axis_names)
110110

tensorflow_probability/python/mcmc/slice_sampler_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def experimental_with_shard_axes(self, shard_axis_names):
351351
def _choose_random_direction(current_state_parts, batch_rank, seed=None,
352352
experimental_shard_axis_names=None):
353353
"""Chooses a random direction in the event space."""
354-
seeds = samplers.split_seed(seed, n=len(current_state_parts))
354+
seeds = list(samplers.split_seed(seed, n=len(current_state_parts)))
355355
seeds = distribute_lib.fold_in_axis_index(
356356
seeds, experimental_shard_axis_names)
357357
# Sample random directions across each of the input components.

0 commit comments

Comments
 (0)