Skip to content

Commit 3b384c4

Browse files
davmretensorflower-gardener
authored andcommitted
Support MarkovChain structures in ASVI.
PiperOrigin-RevId: 380650245
1 parent 6e9d1c0 commit 3b384c4

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

tensorflow_probability/python/experimental/vi/automatic_structured_vi.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from tensorflow_probability.python.distributions import independent
4040
from tensorflow_probability.python.distributions import joint_distribution_auto_batched
4141
from tensorflow_probability.python.distributions import joint_distribution_coroutine
42+
from tensorflow_probability.python.distributions import markov_chain
4243
from tensorflow_probability.python.distributions import sample
4344
from tensorflow_probability.python.distributions import transformed_distribution
4445
from tensorflow_probability.python.distributions import truncated_normal
@@ -297,6 +298,13 @@ def _asvi_surrogate_for_distribution(dist,
297298
dist = _set_name(_as_substituted_distribution(dist), name=_get_name(dist))
298299

299300
# Handle wrapper ("meta") distributions.
301+
if isinstance(dist, markov_chain.MarkovChain):
302+
return _asvi_surrogate_for_markov_chain(
303+
dist=dist,
304+
variables=variables,
305+
base_distribution_surrogate_fn=base_distribution_surrogate_fn,
306+
sample_shape=sample_shape,
307+
seed=seed)
300308
if isinstance(dist, sample.Sample):
301309
dist_sample_shape = distribution_util.expand_to_vector(dist.sample_shape)
302310
nested_surrogate, variables = build_nested_surrogate( # pylint: disable=redundant-keyword-arg
@@ -417,6 +425,58 @@ def posterior_generator(seed=seed):
417425
return surrogate_posterior, variables
418426

419427

428+
def _asvi_surrogate_for_markov_chain(dist,
429+
base_distribution_surrogate_fn,
430+
sample_shape=None,
431+
variables=None,
432+
seed=None):
433+
"""Builds a structured surrogate posterior for a Markov chain."""
434+
prior_seed, transition_seed = samplers.split_seed(seed, 2)
435+
if variables is None:
436+
prior_variables, transition_variables = None, None
437+
else:
438+
prior_variables, transition_variables = variables
439+
440+
surrogate_prior, prior_variables = _asvi_surrogate_for_distribution(
441+
dist.initial_state_prior,
442+
base_distribution_surrogate_fn=base_distribution_surrogate_fn,
443+
variables=prior_variables,
444+
seed=prior_seed)
445+
446+
if transition_variables is None:
447+
# Construct variables for all chain steps in a single call. These will have
448+
# an initial dimension of size `num_steps - 1`, which we can gather from
449+
# as the chain runs.
450+
all_steps = tf.range(dist.num_steps - 1)
451+
batch_state = dist.initial_state_prior.sample(dist.num_steps - 1)
452+
_, transition_variables = _asvi_surrogate_for_distribution(
453+
dist.transition_fn(all_steps, batch_state),
454+
base_distribution_surrogate_fn=base_distribution_surrogate_fn,
455+
variables=None,
456+
sample_shape=sample_shape,
457+
seed=transition_seed)
458+
459+
def surrogate_transition_fn(step, state):
460+
surrogate_new_dist, _ = _asvi_surrogate_for_distribution(
461+
dist.transition_fn(step, state),
462+
base_distribution_surrogate_fn=base_distribution_surrogate_fn,
463+
variables=tf.nest.map_structure(
464+
# Gather parameters for this specific step of the chain.
465+
lambda v: tf.gather(v, step, axis=0), transition_variables),
466+
sample_shape=sample_shape,
467+
seed=transition_seed)
468+
return surrogate_new_dist
469+
470+
chain_surrogate = markov_chain.MarkovChain(
471+
initial_state_prior=surrogate_prior,
472+
transition_fn=surrogate_transition_fn,
473+
num_steps=dist.num_steps,
474+
validate_args=dist.validate_args,
475+
name=_get_name(dist))
476+
477+
return chain_surrogate, [prior_variables, transition_variables]
478+
479+
420480
# TODO(davmre): consider breaking the mean field case into a separate method.
421481
def _asvi_convex_update_for_base_distribution(dist,
422482
mean_field,

tensorflow_probability/python/experimental/vi/automatic_structured_vi_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,38 @@ def nested_model():
286286
return tfd.JointDistributionCoroutineAutoBatched(nested_model)
287287

288288

289+
@test_util.test_all_tf_execution_regimes
290+
class ASVISurrogatePosteriorTestMarkovChain(test_util.TestCase,
291+
_TrainableASVISurrogate):
292+
293+
def _expected_num_trainable_variables(self, _):
294+
return 16
295+
296+
def make_prior_dist(self):
297+
num_timesteps = 10
298+
def stochastic_volatility_prior_fn():
299+
"""Generative process for a stochastic volatility model."""
300+
persistence_of_volatility = 0.9
301+
mean_log_volatility = yield tfd.Cauchy(
302+
loc=0., scale=5., name='mean_log_volatility')
303+
white_noise_shock_scale = yield tfd.HalfCauchy(
304+
loc=0., scale=2., name='white_noise_shock_scale')
305+
_ = yield tfd.MarkovChain(
306+
initial_state_prior=tfd.Normal(
307+
loc=mean_log_volatility,
308+
scale=white_noise_shock_scale / tf.math.sqrt(
309+
tf.ones([]) - persistence_of_volatility**2)),
310+
transition_fn=lambda _, x_t: tfd.Normal( # pylint: disable=g-long-lambda
311+
loc=persistence_of_volatility * (
312+
x_t - mean_log_volatility) + mean_log_volatility,
313+
scale=white_noise_shock_scale),
314+
num_steps=num_timesteps,
315+
name='log_volatility')
316+
317+
return tfd.JointDistributionCoroutineAutoBatched(
318+
stochastic_volatility_prior_fn)
319+
320+
289321
@test_util.test_all_tf_execution_regimes
290322
class TestASVIDistributionSubstitution(test_util.TestCase):
291323

0 commit comments

Comments
 (0)