Skip to content

Commit 49ddde8

Browse files
ColCarrolltensorflower-gardener
authored andcommitted
In windowed mcmc samplers, use harmonic mean when reducing step size across chains, instead of arithmetic.
PiperOrigin-RevId: 375837008
1 parent 2f18733 commit 49ddde8

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

tensorflow_probability/python/experimental/mcmc/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,7 @@ multi_substrate_py_library(
10101010
"//tensorflow_probability/python/internal:prefer_static",
10111011
"//tensorflow_probability/python/internal:samplers",
10121012
"//tensorflow_probability/python/internal:unnest",
1013+
"//tensorflow_probability/python/math:generic",
10131014
"//tensorflow_probability/python/mcmc:dual_averaging_step_size_adaptation",
10141015
"//tensorflow_probability/python/mcmc:sample",
10151016
],

tensorflow_probability/python/experimental/mcmc/windowed_sampling.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from tensorflow_probability.python.internal import samplers
4141
from tensorflow_probability.python.internal import tensorshape_util
4242
from tensorflow_probability.python.internal import unnest
43+
from tensorflow_probability.python.math import generic as generic_math
4344
from tensorflow_probability.python.mcmc import dual_averaging_step_size_adaptation as dassa
4445
from tensorflow_probability.python.mcmc import sample
4546
from tensorflow.python.ops import control_flow_util # pylint: disable=g-direct-tensorflow-import
@@ -538,8 +539,9 @@ def windowed_adaptive_nuts(n_draws,
538539
work. Defaults to the dimension of the log density to the 0.25 power.
539540
dual_averaging_kwargs: Optional dict
540541
Keyword arguments to pass to `tfp.mcmc.DualAveragingStepSizeAdaptation`.
541-
By default, a `target_accept_prob` of 0.85 is set, and the class defaults
542-
are used otherwise.
542+
By default, a `target_accept_prob` of 0.85 is set, acceptance
543+
probabilities across chains are reduced using a harmonic mean, and the
544+
class defaults are used otherwise.
543545
max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
544546
maximum number of leapfrog steps is bounded by `2**max_tree_depth` i.e.
545547
the number of nodes in a binary tree `max_tree_depth` nodes deep. The
@@ -654,8 +656,9 @@ def windowed_adaptive_hmc(n_draws,
654656
work. Defaults to the dimension of the log density to the 0.25 power.
655657
dual_averaging_kwargs: Optional dict
656658
Keyword arguments to pass to `tfp.mcmc.DualAveragingStepSizeAdaptation`.
657-
By default, a `target_accept_prob` of 0.75 is set, and the class defaults
658-
are used otherwise.
659+
By default, a `target_accept_prob` of 0.75 is set, acceptance
660+
probabilities across chains are reduced using a harmonic mean, and the
661+
class defaults are used otherwise.
659662
trace_fn: Optional callable
660663
The trace function should accept the arguments
661664
`(state, bijector, is_adapting, phmc_kernel_results)`, where the `state`
@@ -736,6 +739,8 @@ def _windowed_adaptive_impl(n_draws,
736739
# trace_fn result allocation sizes.
737740
num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)
738741

742+
dual_averaging_kwargs.setdefault('reduce_fn',
743+
generic_math.reduce_log_harmonic_mean_exp)
739744
setup_seed, init_seed, seed = samplers.split_seed(
740745
samplers.sanitize_seed(seed), n=3)
741746
(target_log_prob_fn, initial_transformed_position, bijector,

0 commit comments

Comments
 (0)