Skip to content

Commit e7733ce

Browse files
ColCarroll authored and tensorflower-gardener committed
Fixed a problem where the step size in NUTS and preconditioned NUTS could have a different structure from the one the user passed in.
Specifically, when the step size was a scalar and the user used `trace_fn` or `return_final_kernel_results=True`, the step size reported in the kernel results would be a list. This did not happen in HMC (or preconditioned HMC), and this change makes the behavior consistent. PiperOrigin-RevId: 374703573
1 parent f318f9c commit e7733ce

File tree

4 files changed

+67
-37
lines changed

4 files changed

+67
-37
lines changed

tensorflow_probability/python/experimental/mcmc/pnuts_test.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -589,6 +589,15 @@ def trace_fn(_, pkr):
589589
self.assertAllClose(
590590
average_rhat, np.ones_like(average_rhat), atol=0.05, rtol=0.05)
591591

592+
def test_step_size_trace(self):
593+
dist = tfd.Normal(0., 1.)
594+
kernel = tfp.experimental.mcmc.PreconditionedNoUTurnSampler(
595+
dist.log_prob, step_size=1.)
596+
_, _, fkr = tfp.mcmc.sample_chain(10, 0., kernel=kernel,
597+
return_final_kernel_results=True,
598+
seed=test_util.test_seed())
599+
self.assertAlmostEqual(1., self.evaluate(fkr.step_size))
600+
592601
# Allowed type of preconditioning schemes to use.
593602
# See code for details.
594603
PRECONDITION_SCHEMES = frozenset([

tensorflow_probability/python/experimental/mcmc/preconditioned_nuts.py

Lines changed: 25 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -270,8 +270,6 @@ def __init__(self,
270270

271271
# Process all other arguments.
272272
self._target_log_prob_fn = target_log_prob_fn
273-
if not tf.nest.is_nested(step_size):
274-
step_size = [step_size]
275273
self._step_size = step_size
276274

277275
self._parameters = dict(
@@ -409,12 +407,16 @@ def _copy(v):
409407
read_instruction=read_instruction
410408
)
411409

410+
step_size = _prepare_step_size(
411+
previous_kernel_results.step_size,
412+
current_target_log_prob.dtype,
413+
len(current_state))
412414
_, _, _, new_step_metastate = tf.while_loop(
413415
cond=lambda iter_, seed, state, metastate: ( # pylint: disable=g-long-lambda
414416
(iter_ < self.max_tree_depth) &
415417
tf.reduce_any(metastate.continue_tree)),
416418
body=lambda iter_, seed, state, metastate: self._loop_tree_doubling( # pylint: disable=g-long-lambda
417-
previous_kernel_results.step_size,
419+
step_size,
418420
previous_kernel_results.velocity_state_memory,
419421
current_step_meta_info,
420422
iter_,
@@ -466,21 +468,9 @@ def bootstrap_results(self, init_state):
466468
name='current_state')
467469
current_target_log_prob, current_grads_log_prob = mcmc_util.maybe_call_fn_and_grads(
468470
self.target_log_prob_fn, state_parts)
469-
# Padding the step_size so it is compatable with the states
470-
step_size = self.step_size
471-
if len(step_size) == 1:
472-
step_size = step_size * len(init_state)
473-
if len(step_size) != len(init_state):
474-
raise ValueError('Expected either one step size or {} (size of '
475-
'`init_state`), but found {}'.format(
476-
len(init_state), len(step_size)))
477-
step_size = tf.nest.map_structure(
478-
lambda x: tf.convert_to_tensor( # pylint: disable=g-long-lambda
479-
x,
480-
dtype=current_target_log_prob.dtype,
481-
name='step_size'),
482-
step_size)
483-
471+
# Confirm that the step size is compatible with the state parts.
472+
_ = _prepare_step_size(
473+
self.step_size, current_target_log_prob.dtype, len(init_state))
484474
momentum_distribution = self.momentum_distribution
485475
if momentum_distribution is None:
486476
momentum_distribution = pu.make_momentum_distribution(
@@ -508,7 +498,12 @@ def _init(shape_and_dtype):
508498
target_log_prob=current_target_log_prob,
509499
grads_target_log_prob=current_grads_log_prob,
510500
velocity_state_memory=velocity_state_memory,
511-
step_size=step_size,
501+
step_size=tf.nest.map_structure(
502+
lambda x: tf.convert_to_tensor( # pylint: disable=g-long-lambda
503+
x,
504+
dtype=current_target_log_prob.dtype,
505+
name='step_size'),
506+
self.step_size),
512507
log_accept_ratio=tf.zeros_like(
513508
current_target_log_prob, name='log_accept_ratio'),
514509
leapfrogs_taken=tf.zeros_like(
@@ -1110,3 +1105,14 @@ def compute_hamiltonian(target_log_prob, momentum_parts, momentum_distribution):
11101105
def get_kinetic_energy_fn(momentum_distribution):
11111106
"""Convert a momentum distribution to a kinetic energy function."""
11121107
return lambda *args: -momentum_distribution.log_prob(*args)
1108+
1109+
1110+
def _prepare_step_size(step_size, dtype, n_state_parts):
1111+
step_sizes, _ = mcmc_util.prepare_state_parts(
1112+
step_size, dtype=dtype, name='step_size')
1113+
if len(step_sizes) == 1:
1114+
step_sizes *= n_state_parts
1115+
if n_state_parts != len(step_sizes):
1116+
raise ValueError('There should be exactly one `step_size` or it should '
1117+
'have same length as `current_state`.')
1118+
return step_sizes

tensorflow_probability/python/mcmc/nuts.py

Lines changed: 25 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -262,8 +262,6 @@ def __init__(self,
262262

263263
# Process all other arguments.
264264
self._target_log_prob_fn = target_log_prob_fn
265-
if not tf.nest.is_nested(step_size):
266-
step_size = [step_size]
267265
self._step_size = step_size
268266

269267
self._parameters = dict(
@@ -400,12 +398,16 @@ def _copy(v):
400398
read_instruction=read_instruction
401399
)
402400

401+
step_size = _prepare_step_size(
402+
previous_kernel_results.step_size,
403+
current_target_log_prob.dtype,
404+
len(current_state))
403405
_, _, _, new_step_metastate = tf.while_loop(
404406
cond=lambda iter_, seed, state, metastate: ( # pylint: disable=g-long-lambda
405407
(iter_ < self.max_tree_depth) &
406408
tf.reduce_any(metastate.continue_tree)),
407409
body=lambda iter_, seed, state, metastate: self._loop_tree_doubling( # pylint: disable=g-long-lambda
408-
previous_kernel_results.step_size,
410+
step_size,
409411
previous_kernel_results.momentum_state_memory,
410412
current_step_meta_info,
411413
iter_,
@@ -472,26 +474,20 @@ def _init(shape_and_dtype):
472474
] = leapfrog_impl.process_args(self.target_log_prob_fn, dummy_momentum,
473475
init_state)
474476

475-
# Padding the step_size so it is compatable with the states
476-
step_size = self.step_size
477-
if len(step_size) == 1:
478-
step_size = step_size * len(init_state)
479-
if len(step_size) != len(init_state):
480-
raise ValueError('Expected either one step size or {} (size of '
481-
'`init_state`), but found {}'.format(
482-
len(init_state), len(step_size)))
483-
step_size = tf.nest.map_structure(
484-
lambda x: tf.convert_to_tensor( # pylint: disable=g-long-lambda
485-
x,
486-
dtype=current_target_log_prob.dtype,
487-
name='step_size'),
488-
step_size)
477+
# Confirm that the step size is compatible with the state parts.
478+
_ = _prepare_step_size(
479+
self.step_size, current_target_log_prob.dtype, len(init_state))
489480

490481
return NUTSKernelResults(
491482
target_log_prob=current_target_log_prob,
492483
grads_target_log_prob=current_grads_log_prob,
493484
momentum_state_memory=momentum_state_memory,
494-
step_size=step_size,
485+
step_size=tf.nest.map_structure(
486+
lambda x: tf.convert_to_tensor( # pylint: disable=g-long-lambda
487+
x,
488+
dtype=current_target_log_prob.dtype,
489+
name='step_size'),
490+
self.step_size),
495491
log_accept_ratio=tf.zeros_like(current_target_log_prob,
496492
name='log_accept_ratio'),
497493
leapfrogs_taken=tf.zeros_like(current_target_log_prob,
@@ -1080,6 +1076,17 @@ def generate_efficient_write_read_instruction(instruction_array):
10801076
return write_instruction, np.asarray(read_instruction)
10811077

10821078

1079+
def _prepare_step_size(step_size, dtype, n_state_parts):
1080+
step_sizes, _ = mcmc_util.prepare_state_parts(
1081+
step_size, dtype=dtype, name='step_size')
1082+
if len(step_sizes) == 1:
1083+
step_sizes *= n_state_parts
1084+
if n_state_parts != len(step_sizes):
1085+
raise ValueError('There should be exactly one `step_size` or it should '
1086+
'have same length as `current_state`.')
1087+
return step_sizes
1088+
1089+
10831090
def compute_hamiltonian(target_log_prob, momentum_parts,
10841091
shard_axis_names=None):
10851092
"""Compute the Hamiltonian of the current system."""

tensorflow_probability/python/mcmc/nuts_test.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -574,6 +574,14 @@ def trace_fn(_, pkr):
574574
self.assertAllClose(
575575
average_rhat, np.ones_like(average_rhat), atol=0.05, rtol=0.05)
576576

577+
def test_step_size_trace(self):
578+
dist = tfd.Normal(0., 1.)
579+
kernel = tfp.mcmc.NoUTurnSampler(dist.log_prob, step_size=1.)
580+
_, _, fkr = tfp.mcmc.sample_chain(10, 0., kernel=kernel,
581+
return_final_kernel_results=True,
582+
seed=test_util.test_seed())
583+
self.assertAlmostEqual(1., self.evaluate(fkr.step_size))
584+
577585

578586
@test_util.test_all_tf_execution_regimes
579587
class DistributedNutsTest(distribute_test_lib.DistributedTest):

0 commit comments

Comments
 (0)