
Commit 838888a

davmre authored and tensorflower-gardener committed
Update monte_carlo_variational_loss to use sample_and_log_prob.
This should be no less robust than the current implementation, since the naive fallback for `sample_and_log_prob` just calls `sample` and `log_prob` separately, which is what we were already doing here.

PiperOrigin-RevId: 375619399
1 parent e1d930f · commit 838888a
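For context, the fallback behavior the commit message refers to amounts to the pattern below (a minimal sketch against TFP's public API; the standard-normal surrogate is an illustrative stand-in, not part of this commit):

import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative stand-in surrogate; any tfd.Distribution works here.
surrogate_posterior = tfd.Normal(loc=0., scale=1.)

# Fused path: distributions that can track log-densities during the forward
# sampling pass (e.g. transformed distributions) avoid recomputing them,
# which may otherwise require bijector inverses.
q_samples, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
    10, seed=42)

# Naive fallback: sample, then score the samples in a separate call. This is
# what the fused method reduces to when no specialized implementation exists,
# and what the loss was already doing before this change.
q_samples = surrogate_posterior.sample(10, seed=42)
q_lp = surrogate_posterior.log_prob(q_samples)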

File tree

1 file changed: +18 -11 lines changed


tensorflow_probability/python/vi/csiszar_divergence.py

Lines changed: 18 additions & 11 deletions
@@ -18,6 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 # Dependency imports
 import numpy as np
 
@@ -929,17 +931,6 @@ def monte_carlo_variational_loss(target_log_prob_fn,
 
   """
   with tf.name_scope(name or 'monte_carlo_variational_loss'):
-
-    def divergence_fn(q_samples):
-      target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
-      return discrepancy_fn(
-          target_log_prob - surrogate_posterior.log_prob(
-              q_samples))
-
-    # If Q is joint, drawing samples forces it to build its components. It's
-    # important to do this *before* checking its reparameterization type.
-    q_samples = surrogate_posterior.sample(sample_size, seed=seed)
-
     reparameterization_types = tf.nest.flatten(
         surrogate_posterior.reparameterization_type)
     if use_reparameterization is None:
@@ -959,6 +950,22 @@ def divergence_fn(q_samples):
      raise TypeError('`target_log_prob_fn` must be a Python `callable`'
                      'function.')
 
+    def divergence_fn(q_samples, q_lp=None):
+      target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
+      if q_lp is None:
+        q_lp = surrogate_posterior.log_prob(q_samples)
+      return discrepancy_fn(target_log_prob - q_lp)
+
+    if use_reparameterization:
+      # Attempt to avoid bijector inverses by computing the surrogate log prob
+      # during the forward sampling pass.
+      q_samples, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
+          sample_size, seed=seed)
+      divergence_fn = functools.partial(divergence_fn, q_lp=q_lp)
+    else:
+      # Score fn objective requires explicit gradients of `log_prob`.
+      q_samples = surrogate_posterior.sample(sample_size, seed=seed)
+
     return monte_carlo.expectation(
         f=divergence_fn,
         samples=q_samples,
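For reference, a minimal usage sketch of the function this diff modifies (the standard-normal target and trainable surrogate are hypothetical placeholders, not taken from the commit):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Hypothetical target density: a standard normal.
def target_log_prob_fn(x):
  return tfd.Normal(0., 1.).log_prob(x)

# Trainable surrogate. Normal is fully reparameterized, so the loss takes the
# new experimental_sample_and_log_prob branch; a non-reparameterized surrogate
# would instead hit the score-function branch, which samples and scores
# separately so that gradients of `log_prob` are explicit.
surrogate_posterior = tfd.Normal(loc=tf.Variable(0.), scale=1.)

loss = tfp.vi.monte_carlo_variational_loss(
    target_log_prob_fn,
    surrogate_posterior,
    sample_size=100,
    seed=42)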
