@@ -120,16 +120,19 @@ def snaper_criterion(
120
120
x_center = x_center * (1 - mw ) + mx * mw
121
121
proposed_state = x - jax .lax .stop_gradient (x_center )
122
122
123
- previous_state = jnp .einsum ('d,nd->n' , principal , previous_state )
124
- proposed_state = jnp .einsum ('d,nd->n' , principal , proposed_state )
123
+ previous_projection = jnp .einsum ('d,nd->n' , principal , previous_state )
124
+ proposed_projection = jnp .einsum ('d,nd->n' , principal , proposed_state )
125
125
126
- esjd = ((previous_state ** 2 - proposed_state ** 2 )** 2 )
126
+ esjd = ((previous_projection ** 2 - proposed_projection ** 2 )** 2 )
127
127
128
128
esjd = jnp .where (accept_prob > 1e-4 , esjd , 0. )
129
129
accept_prob = accept_prob / jnp .sum (accept_prob + 1e-20 )
130
130
esjd = esjd * accept_prob
131
131
132
- return esjd .mean () / trajectory_length ** power , ()
132
+ return esjd .mean () / trajectory_length ** power , {
133
+ 'previous_projection' : previous_projection ,
134
+ 'proposed_projection' : proposed_projection ,
135
+ }
133
136
134
137
135
138
class AdaptiveMCMCState (NamedTuple ):
@@ -138,6 +141,7 @@ class AdaptiveMCMCState(NamedTuple):
138
141
Attributes:
139
142
mcmc_state: MCMC state.
140
143
rvar_state: Running variance of the state.
144
+ proj_rautocov_state: Running lag-1 auto covariance of the projections.
141
145
principal_rmean_state: Running mean of the unnormalized principal
142
146
components.
143
147
log_step_size_opt_state: Optimizer state for the log step size.
@@ -150,6 +154,7 @@ class AdaptiveMCMCState(NamedTuple):
150
154
mcmc_state : Union [fun_mc .HamiltonianMonteCarloState ,
151
155
fun_mc .prefab .MetropolisAdjustedLangevinTrajectoriesState ]
152
156
rvar_state : fun_mc .RunningVarianceState
157
+ proj_rautocov_state : fun_mc .RunningCovarianceState
153
158
principal_rmean_state : fun_mc .RunningMeanState
154
159
log_step_size_opt_state : fun_mc .AdamState
155
160
log_trajectory_length_opt_state : fun_mc .AdamState
@@ -165,6 +170,7 @@ class AdaptiveMCMCExtra(NamedTuple):
165
170
mcmc_extra: MCMC extra outputs.
166
171
scalar_step_size: Scalar step size.
167
172
vector_step_size: Vector step size.
173
+ power: Power used in the SNAPER criterion.
168
174
principal: Principal component.
169
175
max_eigenvalue: Maximum eigenvalue.
170
176
mean_trajectory_length: Mean trajectory length.
@@ -177,6 +183,7 @@ class AdaptiveMCMCExtra(NamedTuple):
177
183
fun_mc .prefab .MetropolisAdjustedLangevinTrajectoriesExtra ]
178
184
scalar_step_size : jnp .ndarray
179
185
vector_step_size : jnp .ndarray
186
+ power : jnp .ndarray
180
187
principal : jnp .ndarray
181
188
max_eigenvalue : jnp .ndarray
182
189
mean_trajectory_length : jnp .ndarray
@@ -228,6 +235,7 @@ def adaptive_mcmc_init(state: jnp.ndarray,
228
235
mean = jax .random .normal (
229
236
jax .random .PRNGKey (0 ), state .shape [1 :], state .dtype )),
230
237
rvar_state = fun_mc .running_variance_init (state .shape [1 :], state .dtype ),
238
+ proj_rautocov_state = fun_mc .running_covariance_init ([2 ], state .dtype ),
231
239
log_step_size_opt_state = fun_mc .adam_init (jnp .log (init_step_size )),
232
240
log_trajectory_length_opt_state = fun_mc .adam_init (
233
241
jnp .log (init_trajectory_length )),
@@ -248,6 +256,7 @@ def adaptive_mcmc_step(
248
256
scalar_step_size : Optional [jnp .ndarray ] = None ,
249
257
vector_step_size : Optional [jnp .ndarray ] = None ,
250
258
mean_trajectory_length : Optional [jnp .ndarray ] = None ,
259
+ power : Optional [jnp .ndarray ] = None ,
251
260
principal : Optional [jnp .ndarray ] = None ,
252
261
max_num_integrator_steps : int = 500 ,
253
262
rvar_factor : int = 8 ,
@@ -260,6 +269,7 @@ def adaptive_mcmc_step(
260
269
principal_mean_method : str = 'running_mean' ,
261
270
min_preconditioning_points : int = 64 ,
262
271
state_grad_estimator : str = 'two_dir' ,
272
+ adapt_normalization_power : bool = False ,
263
273
trajectory_opt_kwargs : Mapping [str , Any ] = immutabledict .immutabledict ({}),
264
274
step_size_opt_kwargs : Mapping [str , Any ] = immutabledict .immutabledict ({}),
265
275
):
@@ -277,6 +287,7 @@ def adaptive_mcmc_step(
277
287
vector_step_size: If not None, the fixed vector step size to use.
278
288
mean_trajectory_length: If not None, the fixed mean trajectory length to
279
289
use.
290
+ power: Power used in the SNAPER criterion.
280
291
principal: If not None, the fixed unnormalized principal component to use.
281
292
max_num_integrator_steps: Maximum number of integrator steps.
282
293
rvar_factor: Factor for running variance adaptation rate.
@@ -292,6 +303,8 @@ def adaptive_mcmc_step(
292
303
the running mean for preconditioning:
293
304
state_grad_estimator: State grad estimator to use. Can be 'one_dir' or
294
305
'two_dir'.
306
+ adapt_normalization_power: Whether to adapt the power used for trajectory
307
+ length normalization term in the snaper criterion.
295
308
trajectory_opt_kwargs: Extra arguments to the trajectory length optimizer.
296
309
step_size_opt_kwargs: Extra arguments to the step size optimizer.
297
310
@@ -318,6 +331,13 @@ def adaptive_mcmc_step(
318
331
amcmc_state .rvar_state .num_points > min_preconditioning_points ,
319
332
vector_step_size , jnp .ones_like (vector_step_size ))
320
333
334
+ if power is None :
335
+ if adapt_normalization_power :
336
+ power = ((1. + amcmc_state .proj_rautocov_state .covariance [0 , 1 ] /
337
+ amcmc_state .proj_rautocov_state .covariance [0 , 0 ]) / 2. )
338
+ else :
339
+ power = 1.
340
+
321
341
if principal is None :
322
342
max_eigenvalue = jnp .linalg .norm (amcmc_state .principal_rmean_state .mean )
323
343
principal = amcmc_state .principal_rmean_state .mean / max_eigenvalue
@@ -426,7 +446,7 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
426
446
accept_prob = jnp .exp (jnp .minimum (0. , - mcmc_extra .log_accept_ratio )),
427
447
trajectory_length = trajectory_length + scalar_step_size ,
428
448
principal = principal ,
429
- power = 1. ,
449
+ power = power ,
430
450
# These two expressions are a bit weird for the reverse direction...
431
451
state_mean = amcmc_state .rvar_state .mean ,
432
452
state_mean_weight = (amcmc_state .rvar_state .num_points ) /
@@ -485,6 +505,18 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
485
505
principal_rmean_state = fun_mc .choose (adapt , cand_principal_rmean_state ,
486
506
amcmc_state .principal_rmean_state )
487
507
508
+ # Adjust auto-covariance of the squared projections.
509
+ cand_proj_rautocov_state , _ = fun_mc .running_covariance_step (
510
+ amcmc_state .proj_rautocov_state ,
511
+ jnp .stack ([
512
+ log_trajectory_length_opt_extra .loss_extra [0 ]['proposed_projection' ],
513
+ log_trajectory_length_opt_extra .loss_extra [0 ]['previous_projection' ]
514
+ ], - 1 ),
515
+ axis = 0 ,
516
+ window_size = jnp .maximum (1 , num_chains * amcmc_state .step // rvar_factor ))
517
+ proj_rautocov_state = fun_mc .choose (adapt , cand_proj_rautocov_state ,
518
+ amcmc_state .proj_rautocov_state )
519
+
488
520
# =================
489
521
# Iterate averaging
490
522
# =================
@@ -506,6 +538,7 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
506
538
amcmc_state = amcmc_state ._replace (
507
539
mcmc_state = mcmc_state ,
508
540
rvar_state = rvar_state ,
541
+ proj_rautocov_state = proj_rautocov_state ,
509
542
principal_rmean_state = principal_rmean_state ,
510
543
log_step_size_opt_state = log_step_size_opt_state ,
511
544
log_trajectory_length_opt_state = log_trajectory_length_opt_state ,
@@ -518,6 +551,7 @@ def log_trajectory_length_surrogate_loss_fn(log_trajectory_length):
518
551
scalar_step_size = scalar_step_size ,
519
552
vector_step_size = vector_step_size ,
520
553
principal = principal ,
554
+ power = power ,
521
555
max_eigenvalue = max_eigenvalue ,
522
556
damping = damping ,
523
557
mean_trajectory_length = mean_trajectory_length ,
@@ -575,9 +609,55 @@ def compute_stats(state: jnp.ndarray, num_grads: jnp.ndarray, mean: jnp.ndarray,
575
609
return res
576
610
577
611
612
def get_init_x(target: gym.targets.Model,
               num_chains: int,
               num_prior_samples: int = 256,
               method: str = 'prior_mean') -> jnp.ndarray:
  """Returns a 'good' initializer for MCMC chains.

  With `method='prior_mean'` the initial point is the mean of the target's
  prior when one is available (on `target.model` for a `VectorModel`, otherwise
  on `target` itself). If the prior's `mean()` raises `ValueError` or
  `NotImplementedError`, the mean of `num_prior_samples` prior samples (drawn
  with the fixed seed `PRNGKey(0)`) is used instead. When no
  `prior_distribution` attribute is found, this falls back to the 'z_zero'
  behavior.

  With `method='z_zero'` the initial point is the image of an all-zeros vector
  under `target.default_event_space_bijector`.

  Args:
    target: Target model.
    num_chains: Number of chains to return.
    num_prior_samples: Number of prior samples to use when computing the prior
      mean.
    method: Method to use. Can be either 'prior_mean' or 'z_zero'.

  Returns:
    Initial position of the MCMC chain: the single initial point tiled
    `num_chains` times along a new leading axis.

  Raises:
    ValueError: If `method` is not 'prior_mean' or 'z_zero'.
  """
  # Fail early with a clear message; otherwise an unknown method would only
  # surface as an unbound `init_point` at the final `jnp.tile` call.
  if method not in ('prior_mean', 'z_zero'):
    raise ValueError(
        f"Unknown method: {method!r}. Expected 'prior_mean' or 'z_zero'.")

  def _bijected_zeros():
    # Map the all-zeros point of the bijector's input space to the event space.
    b = target.default_event_space_bijector
    return b(
        jnp.zeros(b.inverse_event_shape(target.event_shape), target.dtype))

  def _prior_mean(prior, to_event):
    # Prefer the analytic mean; fall back to a Monte Carlo estimate when the
    # prior (or the conversion of its mean) does not support it.
    try:
      return to_event(prior.mean())
    except (ValueError, NotImplementedError):
      prior_samples = to_event(
          prior.sample(num_prior_samples, seed=jax.random.PRNGKey(0)))
      return prior_samples.mean(0)

  if method == 'z_zero':
    init_point = _bijected_zeros()
  elif (isinstance(target, gym.targets.VectorModel) and
        hasattr(target.model, 'prior_distribution')):
    # The wrapped model's prior is structured, so flatten it to a vector.
    init_point = _prior_mean(target.model.prior_distribution(),
                             target.structured_event_to_vector)
  elif hasattr(target, 'prior_distribution'):
    init_point = _prior_mean(target.prior_distribution(), lambda x: x)
  else:
    # No prior available: use the same point as 'z_zero'.
    init_point = _bijected_zeros()

  # NOTE(review): the `[num_chains, 1]` reps presume `init_point` is rank-1,
  # giving a [num_chains, d] result -- confirm for non-vector targets.
  return jnp.tile(init_point[jnp.newaxis], [num_chains, 1])
656
+
657
+
578
658
@gin .configurable
579
659
def run_adaptive_mcmc_on_target (
580
- target : gym .targets .VectorModel ,
660
+ target : gym .targets .Model ,
581
661
method : str ,
582
662
num_chains : int ,
583
663
init_step_size : jnp .ndarray ,
@@ -604,11 +684,7 @@ def run_adaptive_mcmc_on_target(
604
684
A tuple of final and traced results.
605
685
"""
606
686
init_z = target .default_event_space_bijector .inverse (
607
- jnp .tile (
608
- target .structured_event_to_vector (
609
- target .model .prior_distribution ().sample (
610
- 256 , seed = jax .random .PRNGKey (0 ))).mean (0 , keepdims = True ),
611
- [num_chains , 1 ]))
687
+ get_init_x (target , num_chains ))
612
688
613
689
def target_log_prob_fn (x ):
614
690
return target .unnormalized_log_prob (x ), ()
@@ -641,6 +717,7 @@ def kernel(amcmc_state, seed):
641
717
'scalar_step_size' : amcmc_extra .scalar_step_size ,
642
718
'vector_step_size' : amcmc_extra .vector_step_size ,
643
719
'principal' : amcmc_extra .principal ,
720
+ 'power' : amcmc_extra .power ,
644
721
'max_eigenvalue' : amcmc_extra .max_eigenvalue ,
645
722
'mean_trajectory_length' : amcmc_extra .mean_trajectory_length ,
646
723
'num_integrator_steps' : amcmc_extra .num_integrator_steps ,
0 commit comments