Skip to content

Commit 3da47c7

Browse files
SiegeLordEx authored and tensorflower-gardener committed
FunMC: Add Persistent HMC.
PiperOrigin-RevId: 383953923
1 parent 2f3a878 commit 3da47c7

File tree

5 files changed

+544
-146
lines changed

5 files changed

+544
-146
lines changed

spinoffs/fun_mc/fun_mc/api.py

Lines changed: 4 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -13,156 +13,14 @@
1313
# limitations under the License.
1414
# ============================================================================
1515
"""FunMC API."""
16+
# pylint: disable=wildcard-import, unused-import
1617

18+
from fun_mc import fun_mc_lib
1719
from fun_mc import prefab
1820
from fun_mc import util_tfp
19-
from fun_mc.fun_mc_lib import adam_init
20-
from fun_mc.fun_mc_lib import adam_step
21-
from fun_mc.fun_mc_lib import AdamExtra
22-
from fun_mc.fun_mc_lib import AdamState
23-
from fun_mc.fun_mc_lib import blanes_3_stage_step
24-
from fun_mc.fun_mc_lib import blanes_4_stage_step
25-
from fun_mc.fun_mc_lib import call_fn
26-
from fun_mc.fun_mc_lib import call_potential_fn
27-
from fun_mc.fun_mc_lib import call_potential_fn_with_grads
28-
from fun_mc.fun_mc_lib import call_transition_operator
29-
from fun_mc.fun_mc_lib import call_transport_map
30-
from fun_mc.fun_mc_lib import call_transport_map_with_ldj
31-
from fun_mc.fun_mc_lib import choose
32-
from fun_mc.fun_mc_lib import gaussian_momentum_sample
33-
from fun_mc.fun_mc_lib import gaussian_proposal
34-
from fun_mc.fun_mc_lib import gradient_descent_init
35-
from fun_mc.fun_mc_lib import gradient_descent_step
36-
from fun_mc.fun_mc_lib import GradientDescentExtra
37-
from fun_mc.fun_mc_lib import GradientDescentState
38-
from fun_mc.fun_mc_lib import hamiltonian_integrator
39-
from fun_mc.fun_mc_lib import hamiltonian_monte_carlo_init
40-
from fun_mc.fun_mc_lib import hamiltonian_monte_carlo_step
41-
from fun_mc.fun_mc_lib import HamiltonianMonteCarloExtra
42-
from fun_mc.fun_mc_lib import HamiltonianMonteCarloState
43-
from fun_mc.fun_mc_lib import IntegratorExtras
44-
from fun_mc.fun_mc_lib import IntegratorState
45-
from fun_mc.fun_mc_lib import IntegratorStep
46-
from fun_mc.fun_mc_lib import IntegratorStepState
47-
from fun_mc.fun_mc_lib import leapfrog_step
48-
from fun_mc.fun_mc_lib import make_gaussian_kinetic_energy_fn
49-
from fun_mc.fun_mc_lib import make_surrogate_loss_fn
50-
from fun_mc.fun_mc_lib import maximal_reflection_coupling_proposal
51-
from fun_mc.fun_mc_lib import MaximalReflectiveCouplingProposalExtra
52-
from fun_mc.fun_mc_lib import maybe_broadcast_structure
53-
from fun_mc.fun_mc_lib import mclachlan_optimal_4th_order_step
54-
from fun_mc.fun_mc_lib import metropolis_hastings_step
55-
from fun_mc.fun_mc_lib import MetropolisHastingsExtra
56-
from fun_mc.fun_mc_lib import potential_scale_reduction_extract
57-
from fun_mc.fun_mc_lib import potential_scale_reduction_init
58-
from fun_mc.fun_mc_lib import potential_scale_reduction_step
59-
from fun_mc.fun_mc_lib import PotentialFn
60-
from fun_mc.fun_mc_lib import PotentialScaleReductionState
61-
from fun_mc.fun_mc_lib import random_walk_metropolis_init
62-
from fun_mc.fun_mc_lib import random_walk_metropolis_step
63-
from fun_mc.fun_mc_lib import RandomWalkMetropolisExtra
64-
from fun_mc.fun_mc_lib import RandomWalkMetropolisState
65-
from fun_mc.fun_mc_lib import recover_state_from_args
66-
from fun_mc.fun_mc_lib import reparameterize_potential_fn
67-
from fun_mc.fun_mc_lib import running_approximate_auto_covariance_init
68-
from fun_mc.fun_mc_lib import running_approximate_auto_covariance_step
69-
from fun_mc.fun_mc_lib import running_covariance_init
70-
from fun_mc.fun_mc_lib import running_covariance_step
71-
from fun_mc.fun_mc_lib import running_mean_init
72-
from fun_mc.fun_mc_lib import running_mean_step
73-
from fun_mc.fun_mc_lib import running_variance_init
74-
from fun_mc.fun_mc_lib import running_variance_step
75-
from fun_mc.fun_mc_lib import RunningApproximateAutoCovarianceState
76-
from fun_mc.fun_mc_lib import RunningCovarianceState
77-
from fun_mc.fun_mc_lib import RunningMeanState
78-
from fun_mc.fun_mc_lib import RunningVarianceState
79-
from fun_mc.fun_mc_lib import ruth4_step
80-
from fun_mc.fun_mc_lib import sign_adaptation
81-
from fun_mc.fun_mc_lib import simple_dual_averages_init
82-
from fun_mc.fun_mc_lib import simple_dual_averages_step
83-
from fun_mc.fun_mc_lib import SimpleDualAveragesExtra
84-
from fun_mc.fun_mc_lib import SimpleDualAveragesState
85-
from fun_mc.fun_mc_lib import splitting_integrator_step
86-
from fun_mc.fun_mc_lib import State
87-
from fun_mc.fun_mc_lib import trace
88-
from fun_mc.fun_mc_lib import transform_log_prob_fn
89-
from fun_mc.fun_mc_lib import TransitionOperator
90-
from fun_mc.fun_mc_lib import TransportMap
21+
from fun_mc.fun_mc_lib import *
9122

9223
__all__ = [
93-
'adam_init',
94-
'adam_step',
95-
'AdamExtra',
96-
'AdamState',
97-
'blanes_3_stage_step',
98-
'blanes_4_stage_step',
99-
'call_fn',
100-
'call_potential_fn',
101-
'call_potential_fn_with_grads',
102-
'call_transition_operator',
103-
'call_transport_map',
104-
'call_transport_map_with_ldj',
105-
'choose',
106-
'gaussian_momentum_sample',
107-
'gaussian_proposal',
108-
'gradient_descent_init',
109-
'gradient_descent_step',
110-
'GradientDescentExtra',
111-
'GradientDescentState',
112-
'hamiltonian_integrator',
113-
'hamiltonian_monte_carlo_init',
114-
'hamiltonian_monte_carlo_step',
115-
'HamiltonianMonteCarloExtra',
116-
'HamiltonianMonteCarloState',
117-
'IntegratorExtras',
118-
'IntegratorState',
119-
'IntegratorStep',
120-
'IntegratorStepState',
121-
'leapfrog_step',
122-
'make_gaussian_kinetic_energy_fn',
123-
'make_surrogate_loss_fn',
124-
'maximal_reflection_coupling_proposal',
125-
'MaximalReflectiveCouplingProposalExtra',
126-
'maybe_broadcast_structure',
127-
'mclachlan_optimal_4th_order_step',
128-
'metropolis_hastings_step',
129-
'MetropolisHastingsExtra',
130-
'potential_scale_reduction_extract',
131-
'potential_scale_reduction_init',
132-
'potential_scale_reduction_step',
133-
'PotentialFn',
134-
'PotentialScaleReductionState',
13524
'prefab',
136-
'random_walk_metropolis_init',
137-
'random_walk_metropolis_step',
138-
'RandomWalkMetropolisExtra',
139-
'RandomWalkMetropolisState',
140-
'reparameterize_potential_fn',
141-
'recover_state_from_args',
142-
'running_approximate_auto_covariance_init',
143-
'running_approximate_auto_covariance_step',
144-
'running_covariance_init',
145-
'running_covariance_step',
146-
'running_mean_init',
147-
'running_mean_step',
148-
'running_variance_init',
149-
'running_variance_step',
150-
'RunningApproximateAutoCovarianceState',
151-
'RunningCovarianceState',
152-
'RunningMeanState',
153-
'RunningVarianceState',
154-
'ruth4_step',
155-
'sign_adaptation',
156-
'simple_dual_averages_init',
157-
'simple_dual_averages_step',
158-
'SimpleDualAveragesExtra',
159-
'SimpleDualAveragesState',
160-
'splitting_integrator_step',
161-
'State',
162-
'trace',
163-
'transform_log_prob_fn',
164-
'TransitionOperator',
165-
'TransportMap',
16625
'util_tfp',
167-
]
168-
26+
] + fun_mc_lib.__all__

spinoffs/fun_mc/fun_mc/fun_mc_lib.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@
8484
'mclachlan_optimal_4th_order_step',
8585
'metropolis_hastings_step',
8686
'MetropolisHastingsExtra',
87+
'persistent_metropolis_hastings_init',
88+
'persistent_metropolis_hastings_step',
89+
'PersistentMetropolistHastingsExtra',
90+
'PersistentMetropolistHastingsState',
8791
'potential_scale_reduction_extract',
8892
'potential_scale_reduction_init',
8993
'potential_scale_reduction_step',
@@ -1083,6 +1087,97 @@ def metropolis_hastings_step(
10831087
is_accepted=is_accepted, log_uniform=log_uniform)
10841088

10851089

1090+
class PersistentMetropolistHastingsState(NamedTuple):
1091+
"""Persistent Metropolis Hastings state.
1092+
1093+
Attributes:
1094+
level: Value uniformly distributed on [-1, 1], absolute value of which is
1095+
used as the slice variable for the acceptance test.
1096+
"""
1097+
# We borrow the [-1, 1] encoding from the original paper; it has the effect of
1098+
# flipping the drift direction automatically, which in turn results in
1099+
# prolonging the persistent bouts of acceptance.
1100+
level: 'FloatTensor'
1101+
1102+
1103+
class PersistentMetropolistHastingsExtra(NamedTuple):
1104+
"""Persistent Metropolis Hastings extra outputs.
1105+
1106+
Attributes:
1107+
is_accepted: Whether the proposed state was accepted.
1108+
accepted_state: `proposed_state` if it was accepted, `current_state` otherwise.
1109+
"""
1110+
is_accepted: 'BooleanTensor'
1111+
accepted_state: 'State'
1112+
1113+
1114+
@util.named_call
1115+
def persistent_metropolis_hastings_init(
1116+
shape: 'IntTensor',
1117+
dtype: 'tf.DType' = tf.float32,
1118+
init_level: 'FloatTensor' = 0.,
1119+
) -> 'PersistentMetropolistHastingsState':
1120+
"""Initializes `PersistentMetropolistHastingsState`.
1121+
1122+
Args:
1123+
shape: Shape of the independent levels.
1124+
dtype: Dtype for the levels.
1125+
init_level: Initial value for the level. Broadcastable to `shape`.
1126+
1127+
Returns:
1128+
pmh_state: `PersistentMetropolistHastingsState`
1129+
"""
1130+
return PersistentMetropolistHastingsState(level=init_level +
1131+
tf.zeros(shape, dtype))
1132+
1133+
1134+
@util.named_call
1135+
def persistent_metropolis_hastings_step(
1136+
pmh_state: 'PersistentMetropolistHastingsState',
1137+
current_state: 'State',
1138+
proposed_state: 'State',
1139+
energy_change: 'FloatTensor',
1140+
drift: 'FloatTensor',
1141+
) -> ('Tuple[PersistentMetropolistHastingsState, '
1142+
'PersistentMetropolistHastingsExtra]'):
1143+
"""Persistent Metropolis Hastings step.
1144+
1145+
This implements the algorithm from [1]. The net effect of this algorithm is
1146+
that accepts/rejects are clustered in time, which helps algorithms that rely
1147+
on persistent momenta. The overall acceptance rate is unaffected. This
1148+
algorithm assumes that the `energy_change` has a continuous distribution
1149+
symmetric about 0 to maintain ergodicity.
1150+
1151+
Args:
1152+
pmh_state: `PersistentMetropolistHastingsState`
1153+
current_state: Current state.
1154+
proposed_state: Proposed state.
1155+
energy_change: E(proposed_state) - E(current_state).
1156+
drift: How much to shift the level variable at each step.
1157+
1158+
Returns:
1159+
pmh_state: New `PersistentMetropolistHastingsState`.
1160+
pmh_extra: `PersistentMetropolistHastingsExtra`.
1161+
1162+
#### References
1163+
1164+
[1]: Neal, R. M. (2020). Non-reversibly updating a uniform [0,1] value for
1165+
Metropolis accept/reject decisions.
1166+
"""
1167+
log_accept_ratio = -energy_change
1168+
is_accepted = tf.math.log(tf.abs(pmh_state.level)) < log_accept_ratio
1169+
# N.B. we'll never accept when energy_change is NaN, so `level` should remain
1170+
# non-NaN at all times.
1171+
level = pmh_state.level
1172+
level = tf.where(is_accepted, level * tf.exp(energy_change), level)
1173+
level += drift
1174+
level = (1 + level) % 2 - 1
1175+
return pmh_state._replace(level=level), PersistentMetropolistHastingsExtra(
1176+
is_accepted=is_accepted,
1177+
accepted_state=choose(is_accepted, proposed_state, current_state),
1178+
)
1179+
1180+
10861181
MomentumSampleFn = Callable[[Any], State]
10871182

10881183

spinoffs/fun_mc/fun_mc/fun_mc_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,6 +1546,48 @@ def inverse_map_fn(x, y):
15461546
self.assertAllClose([2., 3.], orig_space)
15471547
self.assertAllClose(potential, transformed_potential)
15481548

1549+
def testPersistentMH(self):
1550+
1551+
def target_log_prob_fn(x):
1552+
return -x**2 / 2, ()
1553+
1554+
def kernel(pmh_state, rwm_state, seed):
1555+
seed, rwm_seed = util.split_seed(seed, 2)
1556+
# RWM is used to create a valid sequence of energy changes. The
1557+
# correctness of the algorithm relies on the energy changes being
1558+
# symmetric about 0.
1559+
rwm_state, rwm_extra = fun_mc.random_walk_metropolis_step(
1560+
rwm_state,
1561+
target_log_prob_fn=target_log_prob_fn,
1562+
proposal_fn=lambda state, seed: fun_mc.gaussian_proposal( # pylint: disable=g-long-lambda
1563+
state, seed=seed),
1564+
seed=rwm_seed)
1565+
pmh_state, pmh_extra = fun_mc.persistent_metropolis_hastings_step(
1566+
pmh_state,
1567+
# Use dummy states for testing.
1568+
current_state=self._constant(0.),
1569+
proposed_state=self._constant(1.),
1570+
# Coprime with 1000 below.
1571+
drift=0.127,
1572+
energy_change=-rwm_extra.log_accept_ratio)
1573+
return (pmh_state, rwm_state,
1574+
seed), (pmh_extra.is_accepted, pmh_extra.accepted_state,
1575+
rwm_extra.is_accepted)
1576+
1577+
_, (pmh_is_accepted, pmh_accepted_state, rwm_is_accepted) = fun_mc.trace(
1578+
(fun_mc.persistent_metropolis_hastings_init([], self._dtype),
1579+
fun_mc.random_walk_metropolis_init(
1580+
self._constant(0.), target_log_prob_fn),
1581+
self._make_seed(_test_seed())), kernel, 1000)
1582+
1583+
pmh_is_accepted = tf.cast(pmh_is_accepted, self._dtype)
1584+
rwm_is_accepted = tf.cast(rwm_is_accepted, self._dtype)
1585+
self.assertAllClose(
1586+
tf.reduce_mean(rwm_is_accepted),
1587+
tf.reduce_mean(pmh_is_accepted),
1588+
atol=0.05)
1589+
self.assertAllClose(pmh_is_accepted, pmh_accepted_state)
1590+
15491591

15501592
@test_util.multi_backend_test(globals(), 'fun_mc_test')
15511593
class FunMCTest32(FunMCTest):

0 commit comments

Comments
 (0)