Skip to content

Commit aa1f360

Browse files
davmretensorflower-gardener
authored andcommitted
Implement sample_and_log_prob for tfd.Sample.
PiperOrigin-RevId: 375540694
1 parent edb9c5f commit aa1f360

File tree

2 files changed

+43
-28
lines changed

2 files changed

+43
-28
lines changed

tensorflow_probability/python/distributions/sample.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -214,28 +214,31 @@ def _event_shape(self):
214214
return tensorshape_util.concatenate(sample_shape,
215215
self.distribution.event_shape)
216216

217-
def _sample_n(self, n, seed, **kwargs):
218-
sample_shape = ps.reshape(self.sample_shape, shape=[-1])
219-
fake_sample_ndims = ps.rank_from_shape(sample_shape)
217+
def _sampling_permutation(self, sample_ndims):
218+
fake_sample_ndims = ps.rank_from_shape(
219+
ps.reshape(self.sample_shape, shape=[-1]))
220220
event_ndims = ps.rank_from_shape(
221221
self.distribution.event_shape_tensor, self.distribution.event_shape)
222222
batch_ndims = ps.rank_from_shape(
223223
self.distribution.batch_shape_tensor, self.distribution.batch_shape)
224-
perm = ps.concat([
225-
[0],
226-
ps.range(1 + fake_sample_ndims,
227-
1 + fake_sample_ndims + batch_ndims,
224+
return ps.concat([
225+
ps.range(sample_ndims),
226+
ps.range(sample_ndims + fake_sample_ndims,
227+
sample_ndims + fake_sample_ndims + batch_ndims,
228228
dtype=tf.int32),
229-
ps.range(1, 1 + fake_sample_ndims, dtype=tf.int32),
230-
ps.range(1 + fake_sample_ndims + batch_ndims,
231-
1 + fake_sample_ndims + batch_ndims + event_ndims,
229+
ps.range(sample_ndims, sample_ndims + fake_sample_ndims,
230+
dtype=tf.int32),
231+
ps.range(sample_ndims + fake_sample_ndims + batch_ndims,
232+
sample_ndims + fake_sample_ndims + batch_ndims + event_ndims,
232233
dtype=tf.int32),
233234
], axis=0)
234-
x = self.distribution.sample(
235-
ps.concat([[n], sample_shape], axis=0),
236-
seed=seed,
237-
**kwargs)
238-
return tf.transpose(a=x, perm=perm)
235+
236+
def _sample_n(self, n, seed, **kwargs):
237+
sample_shape = ps.reshape(self.sample_shape, shape=[-1])
238+
x = self.distribution.sample(ps.concat([[n], sample_shape], axis=0),
239+
seed=seed,
240+
**kwargs)
241+
return tf.transpose(a=x, perm=self._sampling_permutation(sample_ndims=1))
239242

240243
def _sum_fn(self):
241244
if self._experimental_use_kahan_sum:
@@ -259,20 +262,9 @@ def _prepare_for_underlying(self, x):
259262
ps.shape(x),
260263
paddings=[[ps.maximum(0, -d), 0]],
261264
constant_values=1))
262-
ndims = ps.rank(x)
263265
sample_ndims = ps.maximum(0, d)
264-
# (2) Transpose x's dims.
265-
sample_dims = ps.range(0, sample_ndims)
266-
batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
267-
extra_sample_dims = ps.range(
268-
sample_ndims + batch_ndims,
269-
sample_ndims + batch_ndims + extra_sample_ndims)
270-
event_dims = ps.range(
271-
sample_ndims + batch_ndims + extra_sample_ndims,
272-
ndims)
273-
perm = ps.concat(
274-
[sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
275-
x = tf.transpose(x, perm=perm)
266+
x = tf.transpose(
267+
x, perm=ps.invert_permutation(self._sampling_permutation(sample_ndims)))
276268
return x, (sample_ndims, extra_sample_ndims, batch_ndims)
277269

278270
def _finish_log_prob(self, lp, aux):
@@ -289,6 +281,21 @@ def _finish_log_prob(self, lp, aux):
289281
axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
290282
return self._sum_fn()(lp, axis=axis)
291283

284+
def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
285+
sample_ndims = ps.rank_from_shape(sample_shape)
286+
batch_ndims = ps.rank_from_shape(
287+
self.distribution.batch_shape_tensor,
288+
self.distribution.batch_shape)
289+
extra_sample_shape = ps.reshape(self.sample_shape, shape=[-1])
290+
extra_sample_ndims = ps.rank_from_shape(extra_sample_shape)
291+
x, lp = self.distribution.experimental_sample_and_log_prob(
292+
ps.concat([sample_shape, extra_sample_shape], axis=0), seed=seed,
293+
**kwargs)
294+
return (
295+
tf.transpose(x, perm=self._sampling_permutation(sample_ndims)),
296+
self._finish_log_prob(
297+
lp, aux=(sample_ndims, extra_sample_ndims, batch_ndims)))
298+
292299
def _log_prob(self, x, **kwargs):
293300
x, aux = self._prepare_for_underlying(x)
294301
return self._finish_log_prob(

tensorflow_probability/python/distributions/sample_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ def test_everything_nonscalar(self):
6666
self.assertEqual((6, 1, 3), actual_lp_.shape)
6767
self.assertAllClose(expected_lp_, actual_lp_, atol=0, rtol=1e-3)
6868

69+
def test_sample_and_log_prob(self):
70+
s = tfd.Sample(
71+
tfd.Independent(tfd.Normal(loc=tf.zeros([3, 2]), scale=1), 1), [5, 4],
72+
validate_args=True)
73+
x, lp = s.experimental_sample_and_log_prob([6, 1],
74+
seed=test_util.test_seed())
75+
self.assertAllClose(lp, s.log_prob(x))
76+
6977
def test_mixed_scalar(self):
7078
s = tfd.Sample(tfd.Independent(tfd.Normal(loc=[0., 0], scale=1), 1),
7179
3, validate_args=False)

0 commit comments

Comments
 (0)