37 | 37 |     'adaptive_hamiltonian_monte_carlo_init',
38 | 38 |     'adaptive_hamiltonian_monte_carlo_step',
39 | 39 |     'AdaptiveHamiltonianMonteCarloState',
   | 40 | +    'hamiltonian_monte_carlo_with_state_grads_step',
   | 41 | +    'HamiltonianMonteCarloWithStateGradsExtra',
40 | 42 |     'interactive_trace',
41 | 43 |     'step_size_adaptation_init',
42 | 44 |     'step_size_adaptation_step',
@@ -457,8 +459,8 @@ def interactive_trace(
457 | 459 |     iteration_axis: Integer. Indicates the axis of the trace outputs that should
458 | 460 |       be flattened with the first axis. This is most useful when `fn` is
459 | 461 |       `trace`. E.g. if the trace has shape `[num_steps, 2, 5]` and
460 |     | -     `iteration_axis=2`, the trace outputs will be reshaped/transposed to
461 |     | -     `[2, 5 * num_steps]`. A value of 0 disables this operation.
    | 462 | +     `iteration_axis=2`, the trace outputs will be reshaped/transposed to `[2,
    | 463 | +     5 * num_steps]`. A value of 0 disables this operation.
462 | 464 |     block_until_ready: Whether to wait for the computation to finish between
463 | 465 |       steps. This results in smoother progress bars under, e.g., JAX.
464 | 466 |     progress_bar_fn: A callable that will be called with an iterable with length
@@ -504,13 +506,15 @@ def fn_with_progress(state):
504 | 506 |   )
505 | 507 |
506 | 508 |   if iteration_axis != 0:
    | 509 | +
507 | 510 |     def fix_part(x):
508 | 511 |       x = util.move_axis(x, 0, iteration_axis - 1)
509 | 512 |       x = tf.reshape(
510 | 513 |           x,
511 | 514 |           tuple(x.shape[:iteration_axis - 1]) + (-1,) +
512 | 515 |           tuple(x.shape[iteration_axis + 1:]))
513 | 516 |       return x
    | 517 | +
514 | 518 |     trace = util.map_tree(fix_part, trace)
515 | 519 |   return state, trace
516 | 520 |
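Not part of the diff: a minimal NumPy sketch of what `fix_part` does to a single trace array, assuming `util.move_axis` behaves like `np.moveaxis`, using the `[num_steps, 2, 5]`, `iteration_axis=2` example from the docstring above.

```python
import numpy as np

num_steps, iteration_axis = 3, 2
trace = np.arange(num_steps * 2 * 5).reshape(num_steps, 2, 5)

# Move the leading num_steps axis to just before the iteration axis:
# [num_steps, 2, 5] -> [2, num_steps, 5].
x = np.moveaxis(trace, 0, iteration_axis - 1)

# Flatten num_steps into the iteration axis:
# [2, num_steps, 5] -> [2, num_steps * 5].
new_shape = x.shape[:iteration_axis - 1] + (-1,) + x.shape[iteration_axis + 1:]
x = x.reshape(new_shape)

print(x.shape)  # (2, 15), i.e. [2, 5 * num_steps]
```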
@@ -649,3 +653,120 @@ def step_size_adaptation_step(
649 | 653 |       opt_state=opt_state, rms_state=rms_state, step=state.step + 1)
650 | 654 |   extra = StepSizeAdaptationExtra(opt_extra=opt_extra, accept_prob=accept_prob)
651 | 655 |   return state, extra
    | 656 | +
    | 657 | +
    | 658 | +class HamiltonianMonteCarloWithStateGradsExtra(NamedTuple):
    | 659 | +  """Extra outputs for `hamiltonian_monte_carlo_with_state_grads_step`."""
    | 660 | +  hmc_extra: 'fun_mc.HamiltonianMonteCarloExtra'
    | 661 | +  num_integrator_steps: 'fun_mc.IntTensor'
    | 662 | +  proposed_state: 'fun_mc.State'
    | 663 | +
    | 664 | +
    | 665 | +def hamiltonian_monte_carlo_with_state_grads_step(
    | 666 | +    hmc_state: 'fun_mc.HamiltonianMonteCarloState',
    | 667 | +    trajectory_length: 'fun_mc.FloatTensor',
    | 668 | +    scalar_step_size: 'fun_mc.FloatTensor',
    | 669 | +    step_size_scale: 'fun_mc.FloatNest' = 1.,
    | 670 | +    shard_axis_names: 'fun_mc.StringNest' = (),
    | 671 | +    **hmc_kwargs
    | 672 | +) -> ('Tuple[fun_mc.HamiltonianMonteCarloState, '
    | 673 | +      'HamiltonianMonteCarloWithStateGradsExtra]'):
    | 674 | +  """Hamiltonian Monte Carlo (HMC) step with gradients for the proposed state.
    | 675 | +
    | 676 | +  This acts like `fun_mc.hamiltonian_monte_carlo_step`, where
    | 677 | +  `num_integrator_steps` is defined as `ceil(trajectory_length /
    | 678 | +  scalar_step_size)` and `step_size` is defined as `scalar_step_size *
    | 679 | +  step_size_scale`. The main feature of this function is that it propagates
    | 680 | +  the gradients from `hmc_with_state_grads_extra.proposed_state` to
    | 681 | +  `trajectory_length` (these are the only gradients propagated at the moment).
    | 682 | +  This feature can be used for gradient-based optimization of
    | 683 | +  `trajectory_length` based on criteria that depend on the `proposed_state`
    | 684 | +  (e.g. [1]).
    | 685 | +
    | 686 | +  This function supports SPMD via sharded states, in the same sense as
    | 687 | +  TensorFlow Probability's `tfp.experimental.distribute.Sharded`. Certain
    | 688 | +  state tensors can be annotated as having different values on different
    | 689 | +  devices, with cross-device reductions being inserted accordingly.
    | 690 | +
    | 691 | +  Args:
    | 692 | +    hmc_state: `fun_mc.HamiltonianMonteCarloState`.
    | 693 | +    trajectory_length: Trajectory length used by HMC.
    | 694 | +    scalar_step_size: Scalar step size (used to compute the number of leapfrog
    | 695 | +      steps).
    | 696 | +    step_size_scale: Step size scale, a structure broadcastable to
    | 697 | +      `hmc_state.state`.
    | 698 | +    shard_axis_names: Shard axis names, used for SPMD.
    | 699 | +    **hmc_kwargs: Passed to `fun_mc.hamiltonian_monte_carlo_step`.
    | 700 | +
    | 701 | +  Returns:
    | 702 | +    hmc_state: `fun_mc.HamiltonianMonteCarloState`.
    | 703 | +    hmc_with_grads_extra: Extra outputs.
    | 704 | +
    | 705 | +  #### References
    | 706 | +
    | 707 | +  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2021). An Adaptive MCMC Scheme
    | 708 | +    for Setting Trajectory Lengths in Hamiltonian Monte Carlo.
    | 709 | +    http://proceedings.mlr.press/v130/hoffman21a.html
    | 710 | +  """
    | 711 | +
    | 712 | +  @tf.custom_gradient
    | 713 | +  def hmc(trajectory_length):
    | 714 | +    trajectory_length = tf.convert_to_tensor(trajectory_length)
    | 715 | +    num_integrator_steps = tf.cast(
    | 716 | +        tf.math.ceil(trajectory_length / scalar_step_size), tf.int32)
    | 717 | +    # Ensure at least one step, in case trajectory_length is zero or negative.
    | 718 | +    num_integrator_steps = tf.maximum(1, num_integrator_steps)
    | 719 | +    new_hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(
    | 720 | +        hmc_state,
    | 721 | +        num_integrator_steps=num_integrator_steps,
    | 722 | +        step_size=util.map_tree(lambda s: s * scalar_step_size,
    | 723 | +                                step_size_scale),
    | 724 | +        **hmc_kwargs)
    | 725 | +    hmc_with_grads_extra = HamiltonianMonteCarloWithStateGradsExtra(
    | 726 | +        proposed_state=hmc_extra.proposed_hmc_state.state,
    | 727 | +        hmc_extra=hmc_extra,
    | 728 | +        num_integrator_steps=num_integrator_steps)
    | 729 | +    res = (new_hmc_state, hmc_with_grads_extra)
    | 730 | +
    | 731 | +    def grad(*grads):
    | 732 | +      grads = util.unflatten_tree(res, util.flatten_tree(grads))
    | 733 | +
    | 734 | +      step_size_scale_bc = fun_mc.maybe_broadcast_structure(
    | 735 | +          step_size_scale, hmc_extra.integrator_extra.momentum_grads)
    | 736 | +
    | 737 | +      # We wish to compute `grads^T @
    | 738 | +      # jacobian(proposed_state(trajectory_length))`.
    | 739 | +      #
    | 740 | +      # The Jacobian is known from Hamilton's equations:
    | 741 | +      #
    | 742 | +      # dx / dt = dK(v) / dv
    | 743 | +      #
    | 744 | +      # where `x` is the state, `v` is the momentum and `K` is the kinetic
    | 745 | +      # energy. Since `step_size_scale` rescales the momentum, the right
    | 746 | +      # hand side of that expression is `momentum_grads * step_size_scale`
    | 747 | +      # by the chain rule. Since the Jacobian in question has 1 row, the
    | 748 | +      # vector-Jacobian product is simply the dot product.
    | 749 | +      state_grads = util.map_tree(lambda s, m, g: s * m * g, step_size_scale_bc,
    | 750 | +                                  hmc_extra.integrator_extra.momentum_grads,
    | 751 | +                                  grads[1].proposed_state)
    | 752 | +
    | 753 | +      def do_sum(x, shard_axis_names):
    | 754 | +        res = tf.reduce_sum(
    | 755 | +            x, list(range(len(trajectory_length.shape), len(x.shape))))
    | 756 | +        if shard_axis_names:
    | 757 | +          res = backend.distribute_lib.psum(res, shard_axis_names)
    | 758 | +        return res
    | 759 | +
    | 760 | +      if shard_axis_names:
    | 761 | +        shard_axis_names_bc = shard_axis_names
    | 762 | +      else:
    | 763 | +        shard_axis_names_bc = util.map_tree(lambda _: [], state_grads)
    | 764 | +
    | 765 | +      return sum(
    | 766 | +          util.flatten_tree(
    | 767 | +              util.map_tree_up_to(state_grads, do_sum, state_grads,
    | 768 | +                                  shard_axis_names_bc)))
    | 769 | +
    | 770 | +    return res, grad
    | 771 | +
    | 772 | +  return hmc(trajectory_length)
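Also not part of the diff: a hypothetical usage sketch. It assumes the new step lands in `fun_mc.prefab`, that the TensorFlow backend is imported as `from fun_mc import using_tensorflow`, and that a `[2]`-shaped int32 stateless seed is acceptable; the standard-normal target and squared-norm criterion are illustrative stand-ins (the criterion from [1] would be ChEES).

```python
import tensorflow as tf
from fun_mc import using_tensorflow as fun_mc  # Assumed import path.

prefab = fun_mc.prefab  # Assumes the new step lives in `fun_mc.prefab`.


def target_log_prob_fn(x):
  # Standard normal target; fun_mc's convention is (log_prob, extra).
  return -0.5 * tf.reduce_sum(tf.square(x), -1), ()


hmc_state = fun_mc.hamiltonian_monte_carlo_init(
    tf.zeros([16, 2]), target_log_prob_fn)
trajectory_length = tf.constant(1.)

with tf.GradientTape() as tape:
  tape.watch(trajectory_length)
  hmc_state, hmc_with_grads_extra = (
      prefab.hamiltonian_monte_carlo_with_state_grads_step(
          hmc_state,
          trajectory_length=trajectory_length,
          scalar_step_size=0.1,
          target_log_prob_fn=target_log_prob_fn,
          seed=tf.constant([0, 1], tf.int32)))  # Assumed seed convention.
  # Illustrative criterion: mean squared norm of the proposed state.
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.square(hmc_with_grads_extra.proposed_state), -1))

# The custom gradient routes d(loss)/d(proposed_state) back to
# trajectory_length, so this can drive a trajectory-length optimizer.
d_traj = tape.gradient(loss, trajectory_length)
```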