
Commit 7911fc6

SiegeLordEx authored and tensorflower-gardener committed
FunMC: Add HMC with state gradients.
This is a composable component for implementing gradient-based trajectory adaptation, such as ChEES-HMC.

PiperOrigin-RevId: 377341695
1 parent e888f56 commit 7911fc6


5 files changed: +225 -9 lines changed


spinoffs/fun_mc/fun_mc/dynamic/backend_jax/backend.py

Lines changed: 2 additions & 0 deletions
@@ -17,11 +17,13 @@
 from fun_mc.dynamic.backend_jax import tf_on_jax
 from fun_mc.dynamic.backend_jax import util
 from tensorflow_probability.substrates import jax as tfp
+from tensorflow_probability.substrates.jax.internal import distribute_lib
 from tensorflow_probability.substrates.jax.internal import prefer_static

 tf = tf_on_jax.tf

 __all__ = [
+    'distribute_lib',
     'prefer_static',
     'tf',
     'tfp',

spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py

Lines changed: 2 additions & 0 deletions
@@ -16,10 +16,12 @@

 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
+from tensorflow_probability.python.internal import distribute_lib
 from tensorflow_probability.python.internal import prefer_static
 from fun_mc.dynamic.backend_tensorflow import util

 __all__ = [
+    'distribute_lib',
     'prefer_static',
     'tf',
     'tfp',
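Both backends now re-export `distribute_lib` through the `backend` facade, which is what lets the later prefab.py and test code call `backend.distribute_lib.psum` regardless of backend. A minimal runnable sketch on the JAX backend (not from the commit; the device count, data, and axis name 'i' are illustrative):

import functools
import jax
import jax.numpy as jnp
from fun_mc.dynamic.backend_jax import backend

@functools.partial(jax.pmap, axis_name='i')
def per_device_sum(x):
  # Each device holds one shard of `x`; psum returns the cross-device total,
  # mirroring how the new HMC step reduces sharded state gradients.
  return backend.distribute_lib.psum(jnp.sum(x), 'i')

print(per_device_sum(jnp.arange(float(jax.device_count()))))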

spinoffs/fun_mc/fun_mc/fun_mc_lib.py

Lines changed: 4 additions & 1 deletion
@@ -136,6 +136,7 @@
                     Mapping[Any, BooleanTensor]]
 FloatNest = Union[FloatTensor, Sequence[FloatTensor], Mapping[Any, FloatTensor]]
 IntNest = Union[IntTensor, Sequence[IntTensor], Mapping[Any, IntTensor]]
+StringNest = Union[Text, Sequence[Text], Mapping[Any, Text]]
 DTypeNest = Union['tf.DType', Sequence['tf.DType'], Mapping[Any, 'tf.DType']]
 State = TensorNest  # pylint: disable=invalid-name
 TransitionOperator = Callable[..., Tuple[State, TensorNest]]
@@ -1381,6 +1382,7 @@ class IntegratorExtras(NamedTuple):
   final_kinetic_energy: 'FloatTensor'
   final_kinetic_energy_extra: Any
   integrator_trace: Any
+  momentum_grads: 'State'


 @util.named_call
@@ -1496,7 +1498,8 @@ def integrator_trace_wrapper_fn(args, _):
       final_energy=final_energy,
       final_kinetic_energy=integrator_step_extra.kinetic_energy,
       final_kinetic_energy_extra=integrator_step_extra.kinetic_energy_extra,
-      integrator_trace=integrator_trace)
+      integrator_trace=integrator_trace,
+      momentum_grads=integrator_step_extra.momentum_grads)

   return state, extra


spinoffs/fun_mc/fun_mc/prefab.py

Lines changed: 123 additions & 2 deletions
@@ -37,6 +37,8 @@
     'adaptive_hamiltonian_monte_carlo_init',
     'adaptive_hamiltonian_monte_carlo_step',
     'AdaptiveHamiltonianMonteCarloState',
+    'hamiltonian_monte_carlo_with_state_grads_step',
+    'HamiltonianMonteCarloWithStateGradsExtra',
     'interactive_trace',
     'step_size_adaptation_init',
     'step_size_adaptation_step',
@@ -457,8 +459,8 @@ def interactive_trace(
     iteration_axis: Integer. Indicates the axis of the trace outputs that should
       be flattened with the first axis. This is most useful when `fn` is
      `trace`. E.g. if the trace has shape `[num_steps, 2, 5]` and
-      `iteration_axis=2`, the trace outputs will be reshaped/transposed to
-      `[2, 5 * num_steps]`. A value of 0 disables this operation.
+      `iteration_axis=2`, the trace outputs will be reshaped/transposed to `[2,
+      5 * num_steps]`. A value of 0 disables this operation.
     block_until_ready: Whether to wait for the computation to finish between
       steps. This results in smoother progress bars under, e.g., JAX.
     progress_bar_fn: A callable that will be called with an iterable with length
@@ -504,13 +506,15 @@ def fn_with_progress(state):
   )

   if iteration_axis != 0:
+
     def fix_part(x):
       x = util.move_axis(x, 0, iteration_axis - 1)
       x = tf.reshape(
           x,
           tuple(x.shape[:iteration_axis - 1]) + (-1,) +
           tuple(x.shape[iteration_axis + 1:]))
       return x
+
     trace = util.map_tree(fix_part, trace)
   return state, trace

@@ -649,3 +653,120 @@ def step_size_adaptation_step(
       opt_state=opt_state, rms_state=rms_state, step=state.step + 1)
   extra = StepSizeAdaptationExtra(opt_extra=opt_extra, accept_prob=accept_prob)
   return state, extra
+
+
+class HamiltonianMonteCarloWithStateGradsExtra(NamedTuple):
+  """Extra outputs for `hamiltonian_monte_carlo_with_state_grads_step`."""
+  hmc_extra: 'fun_mc.HamiltonianMonteCarloExtra'
+  num_integrator_steps: 'fun_mc.IntTensor'
+  proposed_state: 'fun_mc.State'
+
+
+def hamiltonian_monte_carlo_with_state_grads_step(
+    hmc_state: 'fun_mc.HamiltonianMonteCarloState',
+    trajectory_length: 'fun_mc.FloatTensor',
+    scalar_step_size: 'fun_mc.FloatTensor',
+    step_size_scale: 'fun_mc.FloatNest' = 1.,
+    shard_axis_names: 'fun_mc.StringNest' = (),
+    **hmc_kwargs
+) -> ('Tuple[fun_mc.HamiltonianMonteCarloState, '
+      'HamiltonianMonteCarloWithStateGradsExtra]'):
+  """Hamiltonian Monte Carlo (HMC) step with gradients for the proposed state.
+
+  This acts as a `fun_mc.hamiltonian_monte_carlo_step`, where
+  `num_integrator_steps` is defined as `ceil(trajectory_length /
+  scalar_step_size)` and `step_size` is defined as `scalar_step_size *
+  step_size_scale`. The main feature of this function is that it propagates
+  the gradients from `hmc_with_state_grads_extra.proposed_state` to
+  `trajectory_length` (these are the only gradients propagated at the moment).
+  This feature can be used for gradient-based optimization of
+  `trajectory_length` based on criteria that depend on the `proposed_state`
+  (e.g. [1]).
+
+  This function supports SPMD via sharded states in the same sense as
+  TensorFlow Probability's `tfp.experimental.distribute.Sharded`. Certain
+  state tensors can be annotated as having different values on different
+  devices, with cross-device reductions being inserted accordingly.
+
+  Args:
+    hmc_state: `fun_mc.HamiltonianMonteCarloState`.
+    trajectory_length: Trajectory length used by HMC.
+    scalar_step_size: Scalar step size (used to compute the number of leapfrog
+      steps).
+    step_size_scale: Step size scale, a structure broadcastable to
+      `hmc_state.state`.
+    shard_axis_names: Shard axis names, used for SPMD.
+    **hmc_kwargs: Passed to `fun_mc.hamiltonian_monte_carlo_step`.
+
+  Returns:
+    hmc_state: `fun_mc.HamiltonianMonteCarloState`.
+    hmc_with_grads_extra: Extra outputs.
+
+  #### References
+
+  [1]: Hoffman, M., Radul, A., & Sountsov, P. (2021). An Adaptive MCMC Scheme
+       for Setting Trajectory Lengths in Hamiltonian Monte Carlo.
+       http://proceedings.mlr.press/v130/hoffman21a.html
+  """
+
+  @tf.custom_gradient
+  def hmc(trajectory_length):
+    trajectory_length = tf.convert_to_tensor(trajectory_length)
+    num_integrator_steps = tf.cast(
+        tf.math.ceil(trajectory_length / scalar_step_size), tf.int32)
+    # In case something goes negative.
+    num_integrator_steps = tf.maximum(1, num_integrator_steps)
+    new_hmc_state, hmc_extra = fun_mc.hamiltonian_monte_carlo_step(
+        hmc_state,
+        num_integrator_steps=num_integrator_steps,
+        step_size=util.map_tree(lambda s: s * scalar_step_size,
+                                step_size_scale),
+        **hmc_kwargs)
+    hmc_with_grads_extra = HamiltonianMonteCarloWithStateGradsExtra(
+        proposed_state=hmc_extra.proposed_hmc_state.state,
+        hmc_extra=hmc_extra,
+        num_integrator_steps=num_integrator_steps)
+    res = (new_hmc_state, hmc_with_grads_extra)
+
+    def grad(*grads):
+      grads = util.unflatten_tree(res, util.flatten_tree(grads))
+
+      step_size_scale_bc = fun_mc.maybe_broadcast_structure(
+          step_size_scale, hmc_extra.integrator_extra.momentum_grads)
+
+      # We wish to compute `grads^T @
+      # jacobian(proposed_state(trajectory_length))`.
+      #
+      # The Jacobian is known from Hamilton's equations:
+      #
+      # dx / dt = dK(v) / dv
+      #
+      # where `x` is the state, `v` is the momentum and `K` is the kinetic
+      # energy. Since `step_size_scale` rescales momentum, the right hand
+      # side of that expression is `momentum_grads * step_size_scale` by the
+      # chain rule. Since the Jacobian in question has 1 row, the
+      # vector-Jacobian product is simply the dot product.
+      state_grads = util.map_tree(lambda s, m, g: s * m * g, step_size_scale_bc,
+                                  hmc_extra.integrator_extra.momentum_grads,
+                                  grads[1].proposed_state)
+
+      def do_sum(x, shard_axis_names):
+        res = tf.reduce_sum(
+            x, list(range(len(trajectory_length.shape), len(x.shape))))
+        if shard_axis_names:
+          res = backend.distribute_lib.psum(res, shard_axis_names)
+        return res
+
+      if shard_axis_names:
+        shard_axis_names_bc = shard_axis_names
+      else:
+        shard_axis_names_bc = util.map_tree(lambda _: [], state_grads)
+
+      return sum(
+          util.flatten_tree(
+              util.map_tree_up_to(state_grads, do_sum, state_grads,
+                                  shard_axis_names_bc)))
+
+    return res, grad
+
+  return hmc(trajectory_length)
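For reference, the vector-Jacobian product computed in `grad` above can be written out explicitly (my notation, not part of the commit): let x(T) be the proposed state after trajectory length T, v the final momentum, K the kinetic energy, s the `step_size_scale`, and g the incoming gradient with respect to the proposed state. Then, following the comment in the code,

\frac{dx}{dt} = \frac{\partial K(v)}{\partial v}
\quad\Longrightarrow\quad
\frac{\partial x(T)}{\partial T} = s \odot \texttt{momentum\_grads},
\qquad
\frac{\partial L}{\partial T} = \sum_i g_i \, s_i \, \big[\texttt{momentum\_grads}\big]_i .

When state tensors are sharded, the per-device partial sums are combined with `distribute_lib.psum` over `shard_axis_names`, exactly as `do_sum` does.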

spinoffs/fun_mc/fun_mc/prefab_test.py

Lines changed: 94 additions & 6 deletions
@@ -14,8 +14,12 @@
 # ============================================================================
 """Tests for prefabs."""

+import functools
+import os
+
 # Dependency imports

+import jax
 from jax.config import config as jax_config
 import numpy as np
 import tensorflow.compat.v2 as real_tf
@@ -30,13 +34,17 @@
 tf = backend.tf
 tfp = backend.tfp
 util = backend.util
+tfd = tfp.distributions

 real_tf.enable_v2_behavior()
 jax_config.update('jax_enable_x64', True)

-
 BACKEND = None  # Rewritten by backends/rewrite.py.

+if BACKEND == 'backend_jax':
+  os.environ['XLA_FLAGS'] = (f'{os.environ.get("XLA_FLAGS", "")} '
+                             '--xla_force_host_platform_device_count=4')
+

 def _test_seed():
   return tfp_test_util.test_seed() % (2**32 - 1)
@@ -169,6 +177,7 @@ def kernel(ssa_state, seed):
     self.assertAllClose(rms_step_size[100], rms_step_size[150])

   def testInteractiveIterationAxis1(self):
+
     def kernel(x):
       return x + 1, x

@@ -184,6 +193,7 @@ def kernel(x):
     self.assertAllClose(99., trace[-1])

   def testInteractiveIterationAxis2(self):
+
     def kernel(x):
       return x + 1, x

@@ -193,16 +203,94 @@ def inner(x):
       return state, trace

     state, trace = prefab.interactive_trace(
-        tf.zeros(2),
-        inner,
-        20,
-        iteration_axis=2,
-        progress_bar_fn=None)
+        tf.zeros(2), inner, 20, iteration_axis=2, progress_bar_fn=None)

     self.assertAllClose([100., 100.], state)
     self.assertEqual([2, 100], list(trace.shape))
     self.assertAllClose([99., 99.], trace[:, -1])

+  def testHMCWithStateGrads(self):
+    trajectory_length = 1.
+    epsilon = 1e-3
+
+    root = tfp.experimental.distribute.JointDistributionCoroutine.Root
+
+    seed = self._make_seed(_test_seed())
+
+    def hmc_step(trajectory_length, axis_name=()):
+
+      @tfp.experimental.distribute.JointDistributionCoroutine
+      def model():
+        z = yield root(tfd.Normal(0., 1))
+        yield tfp.experimental.distribute.Sharded(
+            tfd.Sample(tfd.Normal(z, 1.), 8), axis_name)
+
+      @tfp.experimental.distribute.JointDistributionCoroutine
+      def momentum_dist():
+        yield root(tfd.Normal(0., 2))
+        yield root(
+            tfp.experimental.distribute.Sharded(
+                tfd.Sample(tfd.Normal(0., 3.), 8), axis_name))
+
+      def target_log_prob_fn(x):
+        return model.log_prob(x), ()
+
+      def kinetic_energy_fn(m):
+        return -momentum_dist.log_prob(m), ()
+
+      def momentum_sample_fn(seed):
+        return momentum_dist.sample(2, seed=seed)
+
+      state = model.sample(2, seed=seed)
+      hmc_state = fun_mc.hamiltonian_monte_carlo_init(state, target_log_prob_fn)
+      hmc_state, hmc_extra = (
+          prefab.hamiltonian_monte_carlo_with_state_grads_step(
+              hmc_state,
+              trajectory_length=trajectory_length,
+              scalar_step_size=epsilon,
+              step_size_scale=util.map_tree(lambda x: 1. + tf.abs(x), state),
+              target_log_prob_fn=target_log_prob_fn,
+              seed=seed,
+              kinetic_energy_fn=kinetic_energy_fn,
+              momentum_sample_fn=momentum_sample_fn,
+              shard_axis_names=model.experimental_shard_axis_names))
+
+      def sum_state(x, axis_name):
+        res = tf.reduce_sum(x**2)
+        if axis_name:
+          res = backend.distribute_lib.psum(res, axis_name)
+        return res
+
+      sum_sq = util.map_tree_up_to(hmc_extra.proposed_state, sum_state,
+                                   hmc_extra.proposed_state,
+                                   model.experimental_shard_axis_names)
+      sum_sq = sum(util.flatten_tree(sum_sq))
+      return sum_sq, ()
+
+    def finite_diff_grad(f, epsilon, x):
+      return (fun_mc.call_potential_fn(f, util.map_tree(
+          lambda x: x + epsilon, x))[0] - fun_mc.call_potential_fn(
+              f, util.map_tree(lambda x: x - epsilon, x))[0]) / (2 * epsilon)
+
+    f = tf.function(hmc_step)
+    auto_diff = util.value_and_grad(f, trajectory_length)[2]
+    finite_diff = finite_diff_grad(f, epsilon, trajectory_length)
+
+    self.assertAllClose(auto_diff, finite_diff, rtol=0.01)
+
+    if BACKEND == 'backend_jax':
+
+      @functools.partial(jax.pmap, axis_name='i')
+      def run(_):
+        f = tf.function(lambda trajectory_length: hmc_step(  # pylint: disable=g-long-lambda
+            trajectory_length, axis_name='i'))
+        auto_diff = util.value_and_grad(f, trajectory_length)[2]
+        finite_diff = finite_diff_grad(f, epsilon, trajectory_length)
+        return auto_diff, finite_diff
+
+      auto_diff, finite_diff = run(tf.ones(4))
+      self.assertAllClose(auto_diff, finite_diff, rtol=0.01)
+

 @test_util.multi_backend_test(globals(), 'prefab_test')
 class PrefabTest32(PrefabTest):