|
39 | 39 | from tensorflow_probability.python.distributions import independent
|
40 | 40 | from tensorflow_probability.python.distributions import joint_distribution_auto_batched
|
41 | 41 | from tensorflow_probability.python.distributions import joint_distribution_coroutine
|
| 42 | +from tensorflow_probability.python.distributions import markov_chain |
42 | 43 | from tensorflow_probability.python.distributions import sample
|
43 | 44 | from tensorflow_probability.python.distributions import transformed_distribution
|
44 | 45 | from tensorflow_probability.python.distributions import truncated_normal
|
@@ -297,6 +298,13 @@ def _asvi_surrogate_for_distribution(dist,
|
297 | 298 | dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist))
|
298 | 299 |
|
299 | 300 | # Handle wrapper ("meta") distributions.
|
| 301 | + if isinstance(dist, markov_chain.MarkovChain): |
| 302 | + return _asvi_surrogate_for_markov_chain( |
| 303 | + dist=dist, |
| 304 | + variables=variables, |
| 305 | + base_distribution_surrogate_fn=base_distribution_surrogate_fn, |
| 306 | + sample_shape=sample_shape, |
| 307 | + seed=seed) |
300 | 308 | if isinstance(dist, sample.Sample):
|
301 | 309 | dist_sample_shape = distribution_util.expand_to_vector(dist.sample_shape)
|
302 | 310 | nested_surrogate, variables = build_nested_surrogate( # pylint: disable=redundant-keyword-arg
|
@@ -417,6 +425,58 @@ def posterior_generator(seed=seed):
|
417 | 425 | return surrogate_posterior, variables
|
418 | 426 |
|
419 | 427 |
|
| 428 | +def _asvi_surrogate_for_markov_chain(dist, |
| 429 | + base_distribution_surrogate_fn, |
| 430 | + sample_shape=None, |
| 431 | + variables=None, |
| 432 | + seed=None): |
| 433 | + """Builds a structured surrogate posterior for a Markov chain.""" |
| 434 | + prior_seed, transition_seed = samplers.split_seed(seed, 2) |
| 435 | + if variables is None: |
| 436 | + prior_variables, transition_variables = None, None |
| 437 | + else: |
| 438 | + prior_variables, transition_variables = variables |
| 439 | + |
| 440 | + surrogate_prior, prior_variables = _asvi_surrogate_for_distribution( |
| 441 | + dist.initial_state_prior, |
| 442 | + base_distribution_surrogate_fn=base_distribution_surrogate_fn, |
| 443 | + variables=prior_variables, |
| 444 | + seed=prior_seed) |
| 445 | + |
| 446 | + if transition_variables is None: |
| 447 | + # Construct variables for all chain steps in a single call. These will have |
| 448 | + # an initial dimension of size `num_steps - 1`, which we can gather from |
| 449 | + # as the chain runs. |
| 450 | + all_steps = tf.range(dist.num_steps - 1) |
| 451 | + batch_state = dist.initial_state_prior.sample(dist.num_steps - 1) |
| 452 | + _, transition_variables = _asvi_surrogate_for_distribution( |
| 453 | + dist.transition_fn(all_steps, batch_state), |
| 454 | + base_distribution_surrogate_fn=base_distribution_surrogate_fn, |
| 455 | + variables=None, |
| 456 | + sample_shape=sample_shape, |
| 457 | + seed=transition_seed) |
| 458 | + |
| 459 | + def surrogate_transition_fn(step, state): |
| 460 | + surrogate_new_dist, _ = _asvi_surrogate_for_distribution( |
| 461 | + dist.transition_fn(step, state), |
| 462 | + base_distribution_surrogate_fn=base_distribution_surrogate_fn, |
| 463 | + variables=tf.nest.map_structure( |
| 464 | + # Gather parameters for this specific step of the chain. |
| 465 | + lambda v: tf.gather(v, step, axis=0), transition_variables), |
| 466 | + sample_shape=sample_shape, |
| 467 | + seed=transition_seed) |
| 468 | + return surrogate_new_dist |
| 469 | + |
| 470 | + chain_surrogate = markov_chain.MarkovChain( |
| 471 | + initial_state_prior=surrogate_prior, |
| 472 | + transition_fn=surrogate_transition_fn, |
| 473 | + num_steps=dist.num_steps, |
| 474 | + validate_args=dist.validate_args, |
| 475 | + name=_get_name(dist)) |
| 476 | + |
| 477 | + return chain_surrogate, [prior_variables, transition_variables] |
| 478 | + |
| 479 | + |
420 | 480 | # TODO(davmre): consider breaking the mean field case into a separate method.
|
421 | 481 | def _asvi_convex_update_for_base_distribution(dist,
|
422 | 482 | mean_field,
|
|
0 commit comments