Skip to content

Commit 40a2b25

Browse files
SiegeLordEx authored and tensorflower-gardener committed
Adaptive MALT: Also adapt the normalization power.
PiperOrigin-RevId: 474429276
1 parent 91c45f7 commit 40a2b25

File tree

1 file changed

+88
-11
lines changed

1 file changed

+88
-11
lines changed

discussion/adaptive_malt/adaptive_malt.py

Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,19 @@ def snaper_criterion(
120120
x_center = x_center * (1 - mw) + mx * mw
121121
proposed_state = x - jax.lax.stop_gradient(x_center)
122122

123-
previous_state = jnp.einsum('d,nd->n', principal, previous_state)
124-
proposed_state = jnp.einsum('d,nd->n', principal, proposed_state)
123+
previous_projection = jnp.einsum('d,nd->n', principal, previous_state)
124+
proposed_projection = jnp.einsum('d,nd->n', principal, proposed_state)
125125

126-
esjd = ((previous_state**2 - proposed_state**2)**2)
126+
esjd = ((previous_projection**2 - proposed_projection**2)**2)
127127

128128
esjd = jnp.where(accept_prob > 1e-4, esjd, 0.)
129129
accept_prob = accept_prob / jnp.sum(accept_prob + 1e-20)
130130
esjd = esjd * accept_prob
131131

132-
return esjd.mean() / trajectory_length**power, ()
132+
return esjd.mean() / trajectory_length**power, {
133+
'previous_projection': previous_projection,
134+
'proposed_projection': proposed_projection,
135+
}
133136

134137

135138
class AdaptiveMCMCState(NamedTuple):
@@ -138,6 +141,7 @@ class AdaptiveMCMCState(NamedTuple):
138141
Attributes:
139142
mcmc_state: MCMC state.
140143
rvar_state: Running variance of the state.
144+
proj_rautocov_state: Running lag-1 auto covariance of the projections.
141145
principal_rmean_state: Running mean of the unnormalized principal
142146
components.
143147
log_step_size_opt_state: Optimizer state for the log step size.
@@ -150,6 +154,7 @@ class AdaptiveMCMCState(NamedTuple):
150154
mcmc_state: Union[fun_mc.HamiltonianMonteCarloState,
151155
fun_mc.prefab.MetropolisAdjustedLangevinTrajectoriesState]
152156
rvar_state: fun_mc.RunningVarianceState
157+
proj_rautocov_state: fun_mc.RunningCovarianceState
153158
principal_rmean_state: fun_mc.RunningMeanState
154159
log_step_size_opt_state: fun_mc.AdamState
155160
log_trajectory_length_opt_state: fun_mc.AdamState
@@ -165,6 +170,7 @@ class AdaptiveMCMCExtra(NamedTuple):
165170
mcmc_extra: MCMC extra outputs.
166171
scalar_step_size: Scalar step size.
167172
vector_step_size: Vector step size.
173+
power: Power used in the SNAPER criterion.
168174
principal: Principal component.
169175
max_eigenvalue: Maximum eigenvalue.
170176
mean_trajectory_length: Mean trajectory length.
@@ -177,6 +183,7 @@ class AdaptiveMCMCExtra(NamedTuple):
177183
fun_mc.prefab.MetropolisAdjustedLangevinTrajectoriesExtra]
178184
scalar_step_size: jnp.ndarray
179185
vector_step_size: jnp.ndarray
186+
power: jnp.ndarray
180187
principal: jnp.ndarray
181188
max_eigenvalue: jnp.ndarray
182189
mean_trajectory_length: jnp.ndarray
@@ -228,6 +235,7 @@ def adaptive_mcmc_init(state: jnp.ndarray,
228235
mean=jax.random.normal(
229236
jax.random.PRNGKey(0), state.shape[1:], state.dtype)),
230237
rvar_state=fun_mc.running_variance_init(state.shape[1:], state.dtype),
238+
proj_rautocov_state=fun_mc.running_covariance_init([2], state.dtype),
231239
log_step_size_opt_state=fun_mc.adam_init(jnp.log(init_step_size)),
232240
log_trajectory_length_opt_state=fun_mc.adam_init(
233241
jnp.log(init_trajectory_length)),
@@ -248,6 +256,7 @@ def adaptive_mcmc_step(
248256
scalar_step_size: Optional[jnp.ndarray] = None,
249257
vector_step_size: Optional[jnp.ndarray] = None,
250258
mean_trajectory_length: Optional[jnp.ndarray] = None,
259+
power: Optional[jnp.ndarray] = None,
251260
principal: Optional[jnp.ndarray] = None,
252261
max_num_integrator_steps: int = 500,
253262
rvar_factor: int = 8,
@@ -260,6 +269,7 @@ def adaptive_mcmc_step(
260269
principal_mean_method: str = 'running_mean',
261270
min_preconditioning_points: int = 64,
262271
state_grad_estimator: str = 'two_dir',
272+
adapt_normalization_power: bool = False,
263273
trajectory_opt_kwargs: Mapping[str, Any] = immutabledict.immutabledict({}),
264274
step_size_opt_kwargs: Mapping[str, Any] = immutabledict.immutabledict({}),
265275
):
@@ -277,6 +287,7 @@ def adaptive_mcmc_step(
277287
vector_step_size: If not None, the fixed vector step size to use.
278288
mean_trajectory_length: If not None, the fixed mean trajectory length to
279289
use.
290+
power: Power used in the SNAPER criterion.
280291
principal: If not None, the fixed unnormalized principal component to use.
281292
max_num_integrator_steps: Maximum number of integrator steps.
282293
rvar_factor: Factor for running variance adaptation rate.
@@ -292,6 +303,8 @@ def adaptive_mcmc_step(
292303
the running mean for preconditioning:
293304
state_grad_estimator: State grad estimator to use. Can be 'one_dir' or
294305
'two_dir'.
306+
adapt_normalization_power: Whether to adapt the power used for trajectory
307+
length normalization term in the snaper criterion.
295308
trajectory_opt_kwargs: Extra arguments to the trajectory length optimizer.
296309
step_size_opt_kwargs: Extra arguments to the step size optimizer.
297310
@@ -318,6 +331,13 @@ def adaptive_mcmc_step(
318331
amcmc_state.rvar_state.num_points > min_preconditioning_points,
319332
vector_step_size, jnp.ones_like(vector_step_size))
320333

334+
if power is None:
335+
if adapt_normalization_power:
336+
power = ((1. + amcmc_state.proj_rautocov_state.covariance[0, 1] /
337+
amcmc_state.proj_rautocov_state.covariance[0, 0]) / 2.)
338+
else:
339+
power = 1.
340+
321341
if principal is None:
322342
max_eigenvalue = jnp.linalg.norm(amcmc_state.principal_rmean_state.mean)
323343
principal = amcmc_state.principal_rmean_state.mean / max_eigenvalue
@@ -426,7 +446,7 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
426446
accept_prob=jnp.exp(jnp.minimum(0., -mcmc_extra.log_accept_ratio)),
427447
trajectory_length=trajectory_length + scalar_step_size,
428448
principal=principal,
429-
power=1.,
449+
power=power,
430450
# These two expressions are a bit weird for the reverse direction...
431451
state_mean=amcmc_state.rvar_state.mean,
432452
state_mean_weight=(amcmc_state.rvar_state.num_points) /
@@ -485,6 +505,18 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
485505
principal_rmean_state = fun_mc.choose(adapt, cand_principal_rmean_state,
486506
amcmc_state.principal_rmean_state)
487507

508+
# Adjust auto-covariance of the squared projections.
509+
cand_proj_rautocov_state, _ = fun_mc.running_covariance_step(
510+
amcmc_state.proj_rautocov_state,
511+
jnp.stack([
512+
log_trajectory_length_opt_extra.loss_extra[0]['proposed_projection'],
513+
log_trajectory_length_opt_extra.loss_extra[0]['previous_projection']
514+
], -1),
515+
axis=0,
516+
window_size=jnp.maximum(1, num_chains * amcmc_state.step // rvar_factor))
517+
proj_rautocov_state = fun_mc.choose(adapt, cand_proj_rautocov_state,
518+
amcmc_state.proj_rautocov_state)
519+
488520
# =================
489521
# Iterate averaging
490522
# =================
@@ -506,6 +538,7 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
506538
amcmc_state = amcmc_state._replace(
507539
mcmc_state=mcmc_state,
508540
rvar_state=rvar_state,
541+
proj_rautocov_state=proj_rautocov_state,
509542
principal_rmean_state=principal_rmean_state,
510543
log_step_size_opt_state=log_step_size_opt_state,
511544
log_trajectory_length_opt_state=log_trajectory_length_opt_state,
@@ -518,6 +551,7 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
518551
scalar_step_size=scalar_step_size,
519552
vector_step_size=vector_step_size,
520553
principal=principal,
554+
power=power,
521555
max_eigenvalue=max_eigenvalue,
522556
damping=damping,
523557
mean_trajectory_length=mean_trajectory_length,
@@ -575,9 +609,55 @@ def compute_stats(state: jnp.ndarray, num_grads: jnp.ndarray, mean: jnp.ndarray,
575609
return res
576610

577611

612+
def get_init_x(target: gym.targets.Model,
               num_chains: int,
               num_prior_samples: int = 256,
               method: str = 'prior_mean') -> jnp.ndarray:
  """Returns a 'good' initializer for MCMC chains.

  Every chain is started from the same point; the point is simply tiled
  `num_chains` times along a new leading axis.

  With `method='prior_mean'` the point is chosen as follows:

  1. If `target` is a `gym.targets.VectorModel` whose wrapped `model` exposes
     `prior_distribution`, the prior mean is flattened with
     `target.structured_event_to_vector`.
  2. Otherwise, if `target` itself exposes `prior_distribution`, its prior mean
     is used directly.
  3. Otherwise the function falls back to the 'z_zero' behavior.

  In cases 1 and 2, if computing the mean raises `ValueError` or
  `NotImplementedError`, the mean of `num_prior_samples` prior samples is used
  instead. The samples are drawn with the fixed seed `jax.random.PRNGKey(0)`,
  so the result is deterministic.

  With `method='z_zero'` the point is the image of an all-zeros unconstrained
  vector under `target.default_event_space_bijector`.

  Args:
    target: Target model.
    num_chains: Number of chains to return.
    num_prior_samples: Number of prior samples to use when computing the prior
      mean.
    method: Method to use. Can be either 'prior_mean' or 'z_zero'.

  Returns:
    Initial position of the MCMC chain.

  Raises:
    ValueError: If `method` is not one of 'prior_mean' or 'z_zero'.
  """
  # Validate eagerly: previously an unrecognized method fell through both
  # branches and failed at the final `return` with an opaque
  # UnboundLocalError on `init_point`.
  if method not in ('prior_mean', 'z_zero'):
    raise ValueError(
        f"Unknown method: {method!r}. Expected 'prior_mean' or 'z_zero'.")

  def bijector_at_zero():
    """Maps the unconstrained origin through the default bijector."""
    b = target.default_event_space_bijector
    return b(
        jnp.zeros(b.inverse_event_shape(target.event_shape), target.dtype))

  if method == 'z_zero':
    init_point = bijector_at_zero()
  elif (isinstance(target, gym.targets.VectorModel) and
        hasattr(target.model, 'prior_distribution')):
    prior = target.model.prior_distribution()
    try:
      init_point = target.structured_event_to_vector(prior.mean())
    except (ValueError, NotImplementedError):
      # No usable analytic mean; estimate it by Monte Carlo instead.
      prior_samples = target.structured_event_to_vector(
          prior.sample(num_prior_samples, seed=jax.random.PRNGKey(0)))
      init_point = prior_samples.mean(0)
  elif hasattr(target, 'prior_distribution'):
    prior = target.prior_distribution()
    try:
      init_point = prior.mean()
    except (ValueError, NotImplementedError):
      # No usable analytic mean; estimate it by Monte Carlo instead.
      prior_samples = prior.sample(
          num_prior_samples, seed=jax.random.PRNGKey(0))
      init_point = prior_samples.mean(0)
  else:
    # No prior is available on the target, so use the unconstrained origin.
    init_point = bijector_at_zero()

  # NOTE(review): the [num_chains, 1] tiling assumes `init_point` is a rank-1
  # vector, giving a [num_chains, num_dims] result -- confirm for non-vector
  # targets.
  return jnp.tile(init_point[jnp.newaxis], [num_chains, 1])
656+
657+
578658
@gin.configurable
579659
def run_adaptive_mcmc_on_target(
580-
target: gym.targets.VectorModel,
660+
target: gym.targets.Model,
581661
method: str,
582662
num_chains: int,
583663
init_step_size: jnp.ndarray,
@@ -604,11 +684,7 @@ def run_adaptive_mcmc_on_target(
604684
A tuple of final and traced results.
605685
"""
606686
init_z = target.default_event_space_bijector.inverse(
607-
jnp.tile(
608-
target.structured_event_to_vector(
609-
target.model.prior_distribution().sample(
610-
256, seed=jax.random.PRNGKey(0))).mean(0, keepdims=True),
611-
[num_chains, 1]))
687+
get_init_x(target, num_chains))
612688

613689
def target_log_prob_fn(x):
614690
return target.unnormalized_log_prob(x), ()
@@ -641,6 +717,7 @@ def kernel(amcmc_state, seed):
641717
'scalar_step_size': amcmc_extra.scalar_step_size,
642718
'vector_step_size': amcmc_extra.vector_step_size,
643719
'principal': amcmc_extra.principal,
720+
'power': amcmc_extra.power,
644721
'max_eigenvalue': amcmc_extra.max_eigenvalue,
645722
'mean_trajectory_length': amcmc_extra.mean_trajectory_length,
646723
'num_integrator_steps': amcmc_extra.num_integrator_steps,

0 commit comments

Comments
 (0)