Skip to content

Commit f32c8d4

Browse files
vanderplastensorflower-gardener
authored andcommitted
minimize_stateless: avoid reusing initialization seed
Discovered by running tests with `jax_enable_key_reuse_checks=True`. PiperOrigin-RevId: 614737610
1 parent 5e568e1 commit f32c8d4

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tensorflow_probability/python/math/minimize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ def run_jitted_minimize():
138138
seed_is_none = seed is None
139139
if not seed_is_none:
140140
seed = samplers.sanitize_seed(seed, salt='minimize')
141+
init_seed, seed = samplers.split_seed(seed, n=2)
142+
else:
143+
init_seed = None
141144

142145
if not return_full_length_trace:
143146
# Augment trace to record convergence info, so we can truncate it later.
@@ -153,7 +156,7 @@ def run_jitted_minimize():
153156
initial_optimizer_state) = optimizer_step_fn(
154157
parameters=initial_parameters,
155158
optimizer_state=initial_optimizer_state,
156-
seed=seed)
159+
seed=init_seed)
157160

158161
initial_convergence_criterion_state = ()
159162
if convergence_criterion is not None:

0 commit comments

Comments
 (0)