Commit fd8be3c

davmre authored and tensorflower-gardener committed
Add an experimental_sample_and_log_prob method to TFP Distributions.
This is an alternative to bijector caching, intended to be useful (at least) in cases where the cache would otherwise miss. For example, autobatched `JointDistribution`s break caching (unavoidably?), because the values the bijector sees inside the vmap come from a different graph than those returned to the user. But `sample_and_log_prob` still works in this case, enabling efficient VI. Specialized implementations for specific distributions will be added in a future change.

PiperOrigin-RevId: 374990765
1 parent 4d8ce35 commit fd8be3c
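To make the new API concrete, here is a minimal usage sketch (not part of the commit; the surrogate, target, and seed are illustrative) showing the single-call sample-plus-density pattern that makes VI efficient:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Illustrative VI surrogate: a softplus-transformed normal.
surrogate = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Softplus())

# Draw samples and their log density in one call, with no reliance on the
# bijector cache to make the log_prob cheap.
z, log_q = surrogate.experimental_sample_and_log_prob([8], seed=42)

# Monte Carlo ELBO estimate against a stand-in target log density.
target_log_prob = tfd.Gamma(concentration=2., rate=2.).log_prob
elbo = tf.reduce_mean(target_log_prob(z) - log_q)
```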

5 files changed: +171 -36 lines


tensorflow_probability/python/distributions/distribution.py

Lines changed: 71 additions & 16 deletions
@@ -449,6 +449,7 @@ def _log_prob(self, value):
   - `_default_event_space_bijector`.
   - `_parameter_properties` (to support automatic batch shape derivation,
     batch slicing and other features).
+  - `_sample_and_log_prob`.
 
 Note that subclasses of existing Distributions that redefine `__init__` do
 *not* automatically inherit
@@ -1166,22 +1167,21 @@ def _sample_n(self, n, seed=None, **kwargs):
     raise NotImplementedError('sample_n is not implemented: {}'.format(
         type(self).__name__))
 
-  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
+  def _call_sample_n(self, sample_shape, seed, **kwargs):
     """Wrapper around _sample_n."""
-    with self._name_and_control_scope(name):
-      if JAX_MODE and seed is None:
-        raise ValueError('Must provide JAX PRNGKey as `dist.sample(seed=.)`')
-      sample_shape = ps.convert_to_shape_tensor(
-          ps.cast(sample_shape, tf.int32), name='sample_shape')
-      sample_shape, n = self._expand_sample_shape_to_vector(
-          sample_shape, 'sample_shape')
-      samples = self._sample_n(
-          n, seed=seed() if callable(seed) else seed, **kwargs)
-      batch_event_shape = ps.shape(samples)[1:]
-      final_shape = ps.concat([sample_shape, batch_event_shape], 0)
-      samples = tf.reshape(samples, final_shape)
-      samples = self._set_sample_static_shape(samples, sample_shape)
-      return samples
+    if JAX_MODE and seed is None:
+      raise ValueError('Must provide JAX PRNGKey as `dist.sample(seed=.)`')
+    sample_shape = ps.convert_to_shape_tensor(
+        ps.cast(sample_shape, tf.int32), name='sample_shape')
+    sample_shape, n = self._expand_sample_shape_to_vector(
+        sample_shape, 'sample_shape')
+    samples = self._sample_n(
+        n, seed=seed() if callable(seed) else seed, **kwargs)
+    batch_event_shape = ps.shape(samples)[1:]
+    final_shape = ps.concat([sample_shape, batch_event_shape], 0)
+    samples = tf.reshape(samples, final_shape)
+    samples = self._set_sample_static_shape(samples, sample_shape)
+    return samples
 
   def sample(self, sample_shape=(), seed=None, name='sample', **kwargs):
     """Generate samples of the specified shape.
@@ -1198,7 +1198,62 @@ def sample(self, sample_shape=(), seed=None, name='sample', **kwargs):
     Returns:
       samples: a `Tensor` with prepended dimensions `sample_shape`.
     """
-    return self._call_sample_n(sample_shape, seed, name, **kwargs)
+    with self._name_and_control_scope(name):
+      return self._call_sample_n(sample_shape, seed, **kwargs)
+
+  def _call_sample_and_log_prob(self, sample_shape, seed, **kwargs):
+    """Wrapper around `_sample_and_log_prob`."""
+    if hasattr(self, '_sample_and_log_prob'):
+      sample_shape = ps.convert_to_shape_tensor(
+          ps.cast(sample_shape, tf.int32), name='sample_shape')
+      return self._sample_and_log_prob(
+          distribution_util.expand_to_vector(
+              sample_shape, tensor_name='sample_shape'),
+          seed=seed, **kwargs)
+
+    # Naive default implementation. This calls private, rather than public,
+    # methods, to avoid duplicating the name_and_control_scope.
+    value = self._call_sample_n(sample_shape, seed=seed, **kwargs)
+    if hasattr(self, '_log_prob'):
+      log_prob = self._log_prob(value, **kwargs)
+    elif hasattr(self, '_prob'):
+      log_prob = tf.math.log(self._prob(value, **kwargs))
+    else:
+      raise NotImplementedError('log_prob is not implemented: {}'.format(
+          type(self).__name__))
+    return value, log_prob
+
+  def experimental_sample_and_log_prob(self, sample_shape=(), seed=None,
+                                       name='sample_and_log_prob', **kwargs):
+    """Samples from this distribution and returns the log density of the sample.
+
+    The default implementation simply calls `sample` and `log_prob`:
+
+    ```
+    def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
+      x = self.sample(sample_shape=sample_shape, seed=seed, **kwargs)
+      return x, self.log_prob(x, **kwargs)
+    ```
+
+    However, some subclasses may provide more efficient and/or numerically
+    stable implementations.
+
+    Args:
+      sample_shape: integer `Tensor` desired shape of samples to draw.
+        Default value: `()`.
+      seed: Python integer or `tfp.util.SeedStream` instance, for seeding PRNG.
+        Default value: `None`.
+      name: name to give to the op.
+        Default value: `'sample_and_log_prob'`.
+      **kwargs: Named arguments forwarded to subclass implementation.
+    Returns:
+      samples: a `Tensor`, or structure of `Tensor`s, with prepended dimensions
+        `sample_shape`.
+      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
+        values of type `self.dtype`.
+    """
+    with self._name_and_control_scope(name):
+      return self._call_sample_and_log_prob(sample_shape, seed=seed, **kwargs)
 
   def _call_log_prob(self, value, name, **kwargs):
     """Wrapper around _log_prob."""

tensorflow_probability/python/distributions/distribution_properties_test.py

Lines changed: 26 additions & 0 deletions
@@ -288,6 +288,32 @@ def testDistribution(self, dist_name, data):
       self.assertAllEqual(s1, s2)
 
 
+@test_util.test_all_tf_execution_regimes
+class SampleAndLogProbTest(test_util.TestCase):
+
+  @parameterized.named_parameters(
+      {'testcase_name': dname, 'dist_name': dname}
+      for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys()) +
+                          list(dhps.INSTANTIABLE_META_DISTS)))
+  @hp.given(hps.data())
+  @tfp_hps.tfp_hp_settings()
+  def testDistribution(self, dist_name, data):
+    dist = data.draw(dhps.distributions(dist_name=dist_name, enable_vars=False,
+                                        validate_args=False))
+    seed = test_util.test_seed(sampler_type='stateless')
+    sample_shape = [2, 1]
+    with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
+      s1, lp1 = dist.experimental_sample_and_log_prob(sample_shape, seed=seed)
+      s2 = dist.sample(sample_shape, seed=seed)
+      self.assertAllClose(s1, s2, atol=1e-4)
+
+      # Sanity-check the log prob. The actual values may differ arbitrarily (if
+      # the `sample_and_log_prob` implementation is more stable) or be NaN, but
+      # they should at least have the same shape.
+      lp2 = dist.log_prob(s1)
+      self.assertAllEqual(lp1.shape, lp2.shape)
+
+
 @test_util.test_all_tf_execution_regimes
 class NoNansTest(test_util.TestCase, dhps.TestCase):
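The test hinges on stateless seeds being deterministic: the same seed must yield the same draw from both code paths. A standalone sketch of that property (the distribution and seed are illustrative, and assume stateless-seed support in your TFP build):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.Gamma(concentration=2., rate=1.)
seed = (1, 2)  # stateless seed: same seed, same draw

s1, lp1 = dist.experimental_sample_and_log_prob([2, 1], seed=seed)
s2 = dist.sample([2, 1], seed=seed)
# s1 and s2 agree elementwise; lp1 has the same shape as dist.log_prob(s1).
```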

tensorflow_probability/python/distributions/joint_distribution.py

Lines changed: 7 additions & 8 deletions
@@ -727,14 +727,13 @@ def _flat_resolve_names(self, dummy_name='var'):
   # tactically implement the `_call_sample_n` redirector. We don't want to
   # override the public level because then tfp.layers can't take generic
   # `Distribution.sample` as argument for the `convert_to_tensor_fn` parameter.
-  def _call_sample_n(self, sample_shape, seed, name, value=None, **kwargs):
-    with self._name_and_control_scope(name):
-      return self._sample_n(
-          sample_shape,
-          seed=seed() if callable(seed) else seed,
-          value=self._resolve_value(value=value,
-                                    allow_partially_specified=True,
-                                    **kwargs))
+  def _call_sample_n(self, sample_shape, seed, value=None, **kwargs):
+    return self._sample_n(
+        sample_shape,
+        seed=seed() if callable(seed) else seed,
+        value=self._resolve_value(value=value,
+                                  allow_partially_specified=True,
+                                  **kwargs))
 
   def _execute_model(self,
                      sample_shape=(),
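This redirect is what lets autobatched joint distributions, the motivating case from the commit message, use the new method even though their vmap'd execution defeats bijector caching. A sketch (the two-variable model is illustrative):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Autobatched JDs run the model inside a vmap, so the tensors seen by any
# downstream bijector differ from those returned to the user, and the
# bijector cache misses. Sampling and scoring in one call sidesteps this.
jd = tfd.JointDistributionSequentialAutoBatched([
    tfd.Normal(loc=0., scale=1.),
    lambda z: tfd.LogNormal(loc=z, scale=0.5),
])
xs, lp = jd.experimental_sample_and_log_prob(seed=42)
```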

tensorflow_probability/python/distributions/transformed_distribution.py

Lines changed: 32 additions & 12 deletions
@@ -318,21 +318,41 @@ def _batch_shape(self):
         tf.broadcast_static_shape, tf.nest.flatten(batch_shape))
     return batch_shape
 
-  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
+  def _call_sample_n(self, sample_shape, seed, **kwargs):
     # We override `_call_sample_n` rather than `_sample_n` so we can ensure that
     # the result of `self.bijector.forward` is not modified (and thus caching
     # works).
-    with self._name_and_control_scope(name):
-      distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
-
-      # First, generate samples from the base distribution.
-      x = self.distribution.sample(sample_shape=sample_shape,
-                                   seed=seed,
-                                   **distribution_kwargs)
-      # Apply the bijector's forward transformation. For caching to
-      # work, it is imperative that this is the last modification to the
-      # returned result.
-      return self.bijector.forward(x, **bijector_kwargs)
+    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
+
+    # First, generate samples from the base distribution.
+    x = self.distribution.sample(sample_shape=sample_shape,
+                                 seed=seed,
+                                 **distribution_kwargs)
+    # Apply the bijector's forward transformation. For caching to
+    # work, it is imperative that this is the last modification to the
+    # returned result.
+    return self.bijector.forward(x, **bijector_kwargs)
+
+  def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
+    if not self.bijector._is_injective:  # pylint: disable=protected-access
+      # Computing log_prob with a non-injective bijector requires an explicit
+      # inverse to get all points in the inverse image, so we can't get by
+      # with just doing the forward pass.
+      return super()._sample_and_log_prob(sample_shape, seed=seed, **kwargs)
+
+    distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
+    x, base_distribution_log_prob = (
+        self.distribution.experimental_sample_and_log_prob(
+            sample_shape, seed, **distribution_kwargs))
+    y = self.bijector.forward(x, **bijector_kwargs)
+    fldj = self.bijector.forward_log_det_jacobian(
+        x,
+        event_ndims=tf.nest.map_structure(
+            ps.rank_from_shape,
+            self.distribution.event_shape_tensor()),
+        **bijector_kwargs)
+    return y, (base_distribution_log_prob -
+               tf.cast(fldj, base_distribution_log_prob.dtype))
 
   def _log_prob(self, y, **kwargs):
     distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
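The specialized `_sample_and_log_prob` is the change-of-variables identity log p_Y(y) = log p_X(x) - FLDJ(x), evaluated at the freshly drawn x, so no inverse pass is ever needed. A quick check with a log-normal built from well-known pieces (values illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# y = exp(x) with x ~ Normal(0, 1), i.e. a standard log-normal.
log_normal = tfd.TransformedDistribution(tfd.Normal(0., 1.), tfb.Exp())

y, lp = log_normal.experimental_sample_and_log_prob([4], seed=42)

# Recompute the density the long way round, via the explicit inverse:
x = tf.math.log(y)
lp_manual = (tfd.Normal(0., 1.).log_prob(x)
             - tfb.Exp().forward_log_det_jacobian(x, event_ndims=0))
# lp and lp_manual agree up to floating-point error.
```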

tensorflow_probability/python/distributions/transformed_distribution_test.py

Lines changed: 35 additions & 0 deletions
@@ -259,6 +259,41 @@ def _forward_log_det_jacobian(self, x):
         identity_log_normal.log_prob(
             identity_log_normal.sample([2, 3], seed=test_util.test_seed()))
 
+  def testSampleAndLogprob(self):
+    class ExpForwardOnly(tfb.Bijector):
+
+      def __init__(self):
+        super(ExpForwardOnly, self).__init__(forward_min_event_ndims=0)
+
+      def _forward(self, x):
+        return tf.exp(x)
+
+      def _forward_log_det_jacobian(self, x):
+        return tf.convert_to_tensor(value=x)
+
+    exp_forward_only = ExpForwardOnly()
+
+    mu = 3.0
+    sigma = 0.02
+    log_normal = tfd.TransformedDistribution(
+        distribution=tfd.Normal(loc=mu, scale=sigma),
+        bijector=exp_forward_only)
+
+    sample, log_pdf = self.evaluate(log_normal.experimental_sample_and_log_prob(
+        [2, 3], seed=test_util.test_seed()))
+    expected_log_pdf = stats.lognorm.logpdf(
+        sample, s=sigma, scale=np.exp(mu))
+    self.assertAllClose(expected_log_pdf, log_pdf, rtol=1e-4, atol=0.)
+
+    sample, log_pdf = self.evaluate(
+        log_normal.experimental_sample_and_log_prob(seed=test_util.test_seed()))
+    expected_log_pdf = stats.lognorm.logpdf(
+        sample, s=sigma, scale=np.exp(mu))
+    self.assertAllClose(expected_log_pdf, log_pdf, rtol=1e-4, atol=0.)
+
+    sample2 = self.evaluate(log_normal.sample(seed=test_util.test_seed()))
+    self.assertAllClose(sample, sample2, rtol=1e-4)
+
   def testCachedSamplesInvert(self):
     class ExpInverseOnly(tfb.Bijector):
264299
