
Commit edb9c5f

davmre authored and tensorflower-gardener committed
Implement sample_and_log_prob for joint distributions.
PiperOrigin-RevId: 375534806
1 parent 355f190 commit edb9c5f

File tree

4 files changed: 161 additions, 26 deletions


tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 22 additions & 0 deletions
@@ -98,6 +98,15 @@ def trace_values_only(dist, sample_shape, seed, value=None):
   return ValueWithTrace(value=value, traced=value)


+def trace_values_and_log_probs(dist, sample_shape, seed, value=None):
+  """Draws a sample, and traces both the sampled value and its log density."""
+  if value is None:
+    value, lp = dist.experimental_sample_and_log_prob(sample_shape, seed=seed)
+  else:
+    lp = dist.log_prob(value)
+  return ValueWithTrace(value=value, traced=(value, lp))
+
+
 CALLING_CONVENTION_DESCRIPTION = """
 The measure methods of `JointDistribution` (`log_prob`, `prob`, etc.)
 can be called either by passing a single structure of tensors or by using
@@ -570,6 +579,19 @@ def _sample_n(self, sample_shape, seed, value=None):

     return self._model_unflatten(xs)

+  # TODO(b/189122177): Implement _sample_and_log_prob for distributed JDs.
+  def _sample_and_log_prob(self, sample_shape, seed, value=None, **kwargs):
+    xs, lps = zip(
+        *self._call_execute_model(
+            sample_shape,
+            seed=seed,
+            value=self._resolve_value(value=value,
+                                      allow_partially_specified=True,
+                                      **kwargs),
+            sample_and_trace_fn=trace_values_and_log_probs))
+    return (self._model_unflatten(xs),
+            sum(maybe_check_wont_broadcast(lps, self.validate_args)))
+
   def _map_measure_over_dists(self, attr, value):
     if any(x is None for x in tf.nest.flatten(value)):
       raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))
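
The new trace function above is what powers the public `experimental_sample_and_log_prob` method: the model is executed once, and each component contributes its sampled value and its log density in the same pass. A minimal usage sketch of the resulting API (illustrative toy model, not part of this commit):

import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy model: an Exponential scale and a Normal observation that uses it.
joint = tfd.JointDistributionSequential([
    tfd.Exponential(1.),                  # scale
    lambda scale: tfd.Normal(0., scale),  # x | scale
])

# One traversal of the model yields both the sample and its joint log density.
x, lp = joint.experimental_sample_and_log_prob([3], seed=42)

# Equivalent result, but re-executes the model and, for transformed
# components, may invert bijectors:
lp2 = joint.log_prob(x)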

tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py

Lines changed: 60 additions & 0 deletions
@@ -334,6 +334,66 @@ def coroutine_model():

     self.assertAllClose(*self.evaluate([log_prob, expected_log_prob]))

+  @parameterized.named_parameters(
+      {'testcase_name': 'coroutine',
+       'jd_class': tfd.JointDistributionCoroutineAutoBatched},
+      {'testcase_name': 'sequential',
+       'jd_class': tfd.JointDistributionSequentialAutoBatched},
+      {'testcase_name': 'named',
+       'jd_class': tfd.JointDistributionNamedAutoBatched})
+  def test_sample_and_log_prob(self, jd_class):
+
+    # Define a bijector to detect if/when `inverse` is called.
+    inverted_values = []
+
+    class InverseTracingExp(tfb.Exp):
+
+      def _inverse(self, y):
+        inverted_values.append(y)
+        return tf.math.log(y)
+
+    models = {}
+
+    def coroutine_model():
+      g = yield InverseTracingExp()(tfd.Normal(0., 1.), name='g')
+      df = yield tfd.Exponential(1., name='df')
+      loc = yield tfd.Sample(tfd.Normal(0, g), 20, name='loc')
+      yield tfd.StudentT(df, loc, 1, name='x')
+    models[tfd.JointDistributionCoroutineAutoBatched] = coroutine_model
+
+    models[tfd.JointDistributionSequentialAutoBatched] = [
+        InverseTracingExp()(tfd.Normal(0., 1.), name='g'),
+        tfd.Exponential(1., name='df'),
+        lambda _, g: tfd.Sample(tfd.Normal(0, g), 20, name='loc'),
+        lambda loc, df: tfd.StudentT(df, loc, 1, name='x')
+    ]
+
+    models[tfd.JointDistributionNamedAutoBatched] = collections.OrderedDict((
+        ('g', InverseTracingExp()(tfd.Normal(0., 1.))),
+        ('df', tfd.Exponential(1.)),
+        ('loc', lambda g: tfd.Sample(tfd.Normal(0, g), 20)),
+        ('x', lambda loc, df: tfd.StudentT(df, loc, 1))))
+
+    joint = jd_class(models[jd_class], validate_args=True)
+
+    for sample_shape in ([], [5]):
+      inverted_values.clear()
+      x1, lp1 = self.evaluate(
+          joint.experimental_sample_and_log_prob(
+              sample_shape,
+              seed=test_util.test_seed(sampler_type='seedless'),
+              df=2.7))  # Check that kwargs are supported.
+      x2 = self.evaluate(
+          joint.sample(sample_shape,
+                       seed=test_util.test_seed(sampler_type='seedless'),
+                       df=2.7))
+      self.assertAllCloseNested(x1, x2)
+
+      self.assertLen(inverted_values, 0)
+      lp2 = joint.log_prob(x1)
+      self.assertLen(inverted_values, 1)
+      self.assertAllClose(lp1, lp2)
+
   def test_sample_with_batch_value(self):
     @tfd.JointDistributionCoroutineAutoBatched
     def dist():
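
The `InverseTracingExp` bijector is the crux of this test: computing the log prob during sampling should never trigger a bijector `inverse`, while scoring an externally supplied value must. The same check, sketched against a single transformed distribution, assuming (as the test does) that `experimental_sample_and_log_prob` reuses the forward-sampled pre-image:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

inverse_calls = []

class TracingExp(tfb.Exp):
  # Record every inversion so we can assert on when it happens.
  def _inverse(self, y):
    inverse_calls.append(y)
    return tf.math.log(y)

dist = TracingExp()(tfd.Normal(0., 1.))  # a LogNormal, via Exp

x, lp = dist.experimental_sample_and_log_prob(seed=42)
assert not inverse_calls  # the log prob came from the forward pass

dist.log_prob(x)          # scoring an external value must invert
assert len(inverse_calls) == 1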

tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py

Lines changed: 39 additions & 0 deletions
@@ -334,6 +334,45 @@ def dist():

     self.assertAllClose(*self.evaluate([log_prob, expected_log_prob]))

+  def test_sample_and_log_prob(self):
+
+    # Define a bijector to detect if/when `inverse` is called.
+    inverted_values = []
+
+    class InverseTracingExp(tfb.Exp):
+
+      def _inverse(self, y):
+        inverted_values.append(y)
+        return tf.math.log(y)
+
+    def coroutine_model():
+      g = yield Root(InverseTracingExp()(tfd.Normal(0., 1.), name='g'))
+      df = yield Root(tfd.Exponential(1., name='df'))
+      loc = yield tfd.Sample(tfd.Normal(0, g), 20, name='loc')
+      yield tfd.Independent(tfd.StudentT(df[..., tf.newaxis], loc, 1, name='x'),
+                            reinterpreted_batch_ndims=1)
+
+    joint = tfd.JointDistributionCoroutine(coroutine_model, validate_args=True)
+
+    for sample_shape in ([], [5]):
+      inverted_values.clear()
+      x1, lp1 = self.evaluate(
+          joint.experimental_sample_and_log_prob(
+              sample_shape,
+              seed=test_util.test_seed(sampler_type='seedless'),
+              df=2.7 * tf.ones(sample_shape)  # Check that kwargs are supported.
+          ))
+      x2 = self.evaluate(
+          joint.sample(sample_shape,
+                       seed=test_util.test_seed(sampler_type='seedless'),
+                       df=2.7 * tf.ones(sample_shape)))
+      self.assertAllCloseNested(x1, x2)
+
+      self.assertLen(inverted_values, 0)
+      lp2 = joint.log_prob(x1)
+      self.assertLen(inverted_values, 1)
+      self.assertAllClose(lp1, lp2)
+
   def test_detect_missing_root(self):
     if not tf.executing_eagerly(): return
     # The joint distribution specified below is intended to
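
Note the `df=2.7 * tf.ones(sample_shape)` kwarg: because `_sample_and_log_prob` resolves values with `allow_partially_specified=True`, callers can pin a subset of model variables by name while the rest are drawn, and the returned log prob still scores the full joint. A hedged sketch of that pattern in isolation (illustrative model, not from this commit):

import tensorflow_probability as tfp

tfd = tfp.distributions
Root = tfd.JointDistributionCoroutine.Root

def model():
  df = yield Root(tfd.Exponential(1., name='df'))
  yield tfd.StudentT(df, 0., 1., name='x')

joint = tfd.JointDistributionCoroutine(model)

# Pin `df` by name; only `x` is actually drawn. `values` carries both the
# pinned `df` and the sampled `x`; `lp` scores the full joint at `values`.
values, lp = joint.experimental_sample_and_log_prob(seed=42, df=2.7)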

tensorflow_probability/python/distributions/joint_distribution_sample_path_mixin.py

Lines changed: 40 additions & 26 deletions
@@ -23,6 +23,7 @@
 import tensorflow.compat.v2 as tf

 from tensorflow_probability.python import math as tfp_math
+from tensorflow_probability.python.distributions import joint_distribution as jd_lib
 from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import prefer_static

@@ -123,9 +124,7 @@ def event_shape_tensor(self, sample_shape=(), name='event_shape_tensor'):
           d.event_shape_tensor()))
     return self._model_unflatten(component_shapes)

-  def _map_and_reduce_measure_over_dists(self, attr, reduce_fn, value):
-    """Reduces all non-batch dimensions of the provided measure."""
-    xs = list(self._map_measure_over_dists(attr, value))
+  def _reduce_measure_over_dists(self, xs, reduce_fn):
     num_trailing_batch_dims_treated_as_event = [
         prefer_static.rank_from_shape(
             d.batch_shape_tensor()) - self._batch_ndims
@@ -145,14 +144,31 @@ def _maybe_check_batch_shape(self):
           parts[0], s, message='Component batch shapes are inconsistent.'))
     return assertions

-  def _log_prob(self, value):
+  def _reduce_log_probs_over_dists(self, lps):
     if self._experimental_use_kahan_sum:
-      xs = self._map_and_reduce_measure_over_dists(
-          'log_prob', tfp_math.reduce_kahan_sum, value)
-      return sum(xs).total
-    xs = self._map_and_reduce_measure_over_dists(
-        'log_prob', tf.reduce_sum, value)
-    return sum(xs)
+      return sum(jd_lib.maybe_check_wont_broadcast(
+          self._reduce_measure_over_dists(
+              lps, reduce_fn=tfp_math.reduce_kahan_sum),
+          self.validate_args)).total
+    else:
+      return sum(jd_lib.maybe_check_wont_broadcast(
+          self._reduce_measure_over_dists(lps, reduce_fn=tf.reduce_sum),
+          self.validate_args))
+
+  def _sample_and_log_prob(self, sample_shape, seed, value=None, **kwargs):
+    xs, lps = zip(
+        *self._call_execute_model(
+            sample_shape,
+            seed=seed,
+            value=self._resolve_value(value=value,
+                                      allow_partially_specified=True,
+                                      **kwargs),
+            sample_and_trace_fn=jd_lib.trace_values_and_log_probs))
+    return self._model_unflatten(xs), self._reduce_log_probs_over_dists(lps)
+
+  def _log_prob(self, value):
+    return self._reduce_log_probs_over_dists(
+        self._map_measure_over_dists('log_prob', value))

   def log_prob_parts(self, value, name='log_prob_parts'):
     """Log probability density/mass function, part-wise.
@@ -172,18 +188,14 @@ def log_prob_parts(self, value, name='log_prob_parts'):
       sum_fn = tf.reduce_sum
       if self._experimental_use_kahan_sum:
         sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
-      xs = self._map_and_reduce_measure_over_dists(
-          'log_prob', sum_fn, value)
-      return self._model_unflatten(xs)
+      return self._model_unflatten(
+          self._reduce_measure_over_dists(
+              self._map_measure_over_dists('log_prob', value),
+              sum_fn))

   def _unnormalized_log_prob(self, value):
-    if self._experimental_use_kahan_sum:
-      xs = self._map_and_reduce_measure_over_dists(
-          'unnormalized_log_prob', tfp_math.reduce_kahan_sum, value)
-      return sum(xs).total
-    xs = self._map_and_reduce_measure_over_dists(
-        'unnormalized_log_prob', tf.reduce_sum, value)
-    return sum(xs)
+    return self._reduce_log_probs_over_dists(
+        self._map_measure_over_dists('unnormalized_log_prob', value))

   def unnormalized_log_prob_parts(self, value, name=None):
     """Unnormalized log probability density/mass function, part-wise.
@@ -203,9 +215,10 @@ def unnormalized_log_prob_parts(self, value, name=None):
       sum_fn = tf.reduce_sum
       if self._experimental_use_kahan_sum:
         sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
-      xs = self._map_and_reduce_measure_over_dists(
-          'unnormalized_log_prob', sum_fn, value)
-      return self._model_unflatten(xs)
+      return self._model_unflatten(
+          self._reduce_measure_over_dists(
+              self._map_measure_over_dists('unnormalized_log_prob', value),
+              sum_fn))

   def prob_parts(self, value, name='prob_parts'):
     """Log probability density/mass function.
@@ -221,9 +234,10 @@ def prob_parts(self, value, name='prob_parts'):
       each `distribution_fn` evaluated at each corresponding `value`.
     """
     with self._name_and_control_scope(name):
-      xs = self._map_and_reduce_measure_over_dists(
-          'prob', tf.reduce_prod, value)
-      return self._model_unflatten(xs)
+      return self._model_unflatten(
+          self._reduce_measure_over_dists(
+              self._map_measure_over_dists('prob', value),
+              tf.reduce_prod))

   def is_scalar_batch(self, name='is_scalar_batch'):
     """Indicates that `batch_shape == []`.
