|
40 | 40 | from tensorflow_probability.python.internal import samplers
|
41 | 41 | from tensorflow_probability.python.internal import tensorshape_util
|
42 | 42 | from tensorflow_probability.python.internal import unnest
|
| 43 | +from tensorflow_probability.python.math import generic as generic_math |
43 | 44 | from tensorflow_probability.python.mcmc import dual_averaging_step_size_adaptation as dassa
|
44 | 45 | from tensorflow_probability.python.mcmc import sample
|
45 | 46 | from tensorflow.python.ops import control_flow_util # pylint: disable=g-direct-tensorflow-import
|
@@ -538,8 +539,9 @@ def windowed_adaptive_nuts(n_draws,
|
538 | 539 | work. Defaults to the dimension of the log density to the 0.25 power.
|
539 | 540 | dual_averaging_kwargs: Optional dict
|
540 | 541 | Keyword arguments to pass to `tfp.mcmc.DualAveragingStepSizeAdaptation`.
|
541 |
| - By default, a `target_accept_prob` of 0.85 is set, and the class defaults |
542 |
| - are used otherwise. |
| 542 | + By default, a `target_accept_prob` of 0.85 is set, acceptance |
| 543 | + probabilities across chains are reduced using a harmonic mean, and the |
| 544 | + class defaults are used otherwise. |
543 | 545 | max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
|
544 | 546 | maximum number of leapfrog steps is bounded by `2**max_tree_depth` i.e.
|
545 | 547 | the number of nodes in a binary tree `max_tree_depth` nodes deep. The
|
@@ -654,8 +656,9 @@ def windowed_adaptive_hmc(n_draws,
|
654 | 656 | work. Defaults to the dimension of the log density to the 0.25 power.
|
655 | 657 | dual_averaging_kwargs: Optional dict
|
656 | 658 | Keyword arguments to pass to `tfp.mcmc.DualAveragingStepSizeAdaptation`.
|
657 |
| - By default, a `target_accept_prob` of 0.75 is set, and the class defaults |
658 |
| - are used otherwise. |
| 659 | + By default, a `target_accept_prob` of 0.75 is set, acceptance |
| 660 | + probabilities across chains are reduced using a harmonic mean, and the |
| 661 | + class defaults are used otherwise. |
659 | 662 | trace_fn: Optional callable
|
660 | 663 | The trace function should accept the arguments
|
661 | 664 | `(state, bijector, is_adapting, phmc_kernel_results)`, where the `state`
|
@@ -736,6 +739,8 @@ def _windowed_adaptive_impl(n_draws,
|
736 | 739 | # trace_fn result allocation sizes.
|
737 | 740 | num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)
|
738 | 741 |
|
| 742 | + dual_averaging_kwargs.setdefault('reduce_fn', |
| 743 | + generic_math.reduce_log_harmonic_mean_exp) |
739 | 744 | setup_seed, init_seed, seed = samplers.split_seed(
|
740 | 745 | samplers.sanitize_seed(seed), n=3)
|
741 | 746 | (target_log_prob_fn, initial_transformed_position, bijector,
|
|
0 commit comments