Skip to content

Commit e888f56

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
Fix handling of scalar argument to init_step_size in windowed samplers.
This had been erroneously broadcast up to the shape of the state part, which meant there was no way to have a single step size shared between all state parts, other than relying on the default. PiperOrigin-RevId: 377331772
1 parent 85de528 commit e888f56

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

tensorflow_probability/python/experimental/mcmc/windowed_sampling.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,13 @@ def target_log_prob_fn(*args):
255255
return lp + ldj
256256

257257
def step_broadcast(step_size):
258-
return step_bijector(
259-
nest_util.broadcast_structure(pinned_model.event_shape_tensor(),
260-
step_size))
258+
# Only apply the bijector to nested step sizes or non-scalar batches.
259+
if tf.nest.is_nested(step_size):
260+
return step_bijector(
261+
nest_util.broadcast_structure(pinned_model.event_shape_tensor(),
262+
step_size))
263+
else:
264+
return step_size
261265

262266
return (target_log_prob_fn,
263267
initial_transformed_position,

tensorflow_probability/python/experimental/mcmc/windowed_sampling_test.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,7 @@ def model_fn():
428428
num_adaptation_steps=100, init_step_size=tf.ones([10, 1]),
429429
seed=test_util.test_seed()))
430430
self.assertEqual((2, 64, 10, 3), states.x.shape)
431-
self.assertLen(trace['step_size'], 1)
432-
self.assertEqual((2, 10, 1), trace['step_size'][0].shape)
431+
self.assertEqual((2, 10, 1), trace['step_size'].shape)
433432

434433
def test_batch_of_problems_named(self):
435434
use_multinomial = tf.executing_eagerly()
@@ -457,8 +456,7 @@ def mk_y(x):
457456
init_step_size=tf.ones([10, 1]),
458457
seed=test_util.test_seed()))
459458
self.assertEqual((2, 64, 10, 3), states['x'].shape)
460-
self.assertLen(trace['step_size'], 1)
461-
self.assertEqual((2, 10, 1), trace['step_size'][0].shape)
459+
self.assertEqual((2, 10, 1), trace['step_size'].shape)
462460

463461

464462
@test_util.test_graph_and_eager_modes
@@ -529,21 +527,20 @@ def test_supply_single_step_size(self):
529527
})
530528

531529
init_step_size = 1.
532-
_, actual_step_size = tfp.experimental.mcmc.windowed_adaptive_hmc(
533-
1,
534-
jd_model,
535-
num_adaptation_steps=100,
536-
n_chains=20,
537-
init_step_size=init_step_size,
538-
num_leapfrog_steps=5,
539-
discard_tuning=False,
540-
trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
541-
seed=stream(),
542-
)
543-
544-
actual_step = [j[0] for j in actual_step_size]
545-
expected_step = [1., 1.]
546-
self.assertAllCloseNested(expected_step, actual_step)
530+
_, traced_step_size = self.evaluate(
531+
tfp.experimental.mcmc.windowed_adaptive_hmc(
532+
1,
533+
jd_model,
534+
num_adaptation_steps=100,
535+
n_chains=20,
536+
init_step_size=init_step_size,
537+
num_leapfrog_steps=5,
538+
discard_tuning=False,
539+
trace_fn=lambda *args: unnest.get_innermost(args[-1], 'step_size'),
540+
seed=stream()))
541+
542+
self.assertEqual((100 + 1,), traced_step_size.shape)
543+
self.assertAllClose(1., traced_step_size[0])
547544

548545
def test_sequential_step_size(self):
549546
stream = test_util.test_seed_stream()

0 commit comments

Comments
 (0)