Skip to content

Commit 9a14b9b

Browse files
vanderplastensorflower-gardener
authored andcommitted
autoregressive: explicitly clone rng for reuse
PiperOrigin-RevId: 614079072
1 parent 0435c36 commit 9a14b9b

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

tensorflow_probability/python/distributions/autoregressive.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,18 @@ def _sample_n(self, n, seed=None):
294294
if num_steps_static is not None:
295295
for _ in range(num_steps_static):
296296
# pylint: disable=not-callable
297-
samples = self.distribution_fn(samples).sample(seed=seed)
297+
samples = self.distribution_fn(samples).sample(
298+
seed=samplers.clone_seed(seed)
299+
)
298300
else:
299301
# pylint: disable=not-callable
300-
samples = tf.foldl(lambda s, _: self.distribution_fn(s).sample(seed=seed),
301-
elems=tf.range(0, num_steps), initializer=samples)
302+
samples = tf.foldl(
303+
lambda s, _: self.distribution_fn(s).sample(
304+
seed=samplers.clone_seed(seed)
305+
),
306+
elems=tf.range(0, num_steps),
307+
initializer=samples,
308+
)
302309
return samples
303310

304311
def _log_prob(self, value):

tensorflow_probability/python/internal/samplers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
__all__ = [
3333
'categorical',
34+
'clone_seed',
3435
'fold_in',
3536
'gamma',
3637
'is_stateful_seed',
@@ -229,6 +230,16 @@ def split_seed(seed, n=2, salt=None, name=None):
229230
return seeds
230231

231232

233+
def clone_seed(seed):
234+
"""Clones a seed so it can be reused without causing a JAX KeyReuseError."""
235+
if JAX_MODE:
236+
from jax import random as jaxrand # pylint: disable=g-import-not-at-top
237+
if hasattr(jaxrand, 'clone'):
238+
# JAX v0.4.26+
239+
return jaxrand.clone(seed)
240+
return seed
241+
242+
232243
def categorical(
233244
logits,
234245
num_samples,

0 commit comments

Comments
 (0)