Commit 193fc50

davmre authored and tensorflower-gardener committed
Implement sample_and_log_prob for BatchBroadcast.
PiperOrigin-RevId: 375770302
1 parent 838888a commit 193fc50

File tree

2 files changed (+83, −29 lines)

tensorflow_probability/python/distributions/batch_broadcast.py

Lines changed: 71 additions & 28 deletions
@@ -224,47 +224,90 @@ def _event_shape(self):
   def _event_shape_tensor(self):
     return self.distribution.event_shape_tensor()
 
-  def _sample_n(self, n, seed=None):
+  def _augment_sample_shape(self, sample_shape):
+    # Suppose we have:
+    # - sample shape of `[n]`,
+    # - underlying distribution batch shape of `[2, 1]`,
+    # - final broadcast batch shape of `[4, 2, 3]`.
+    # Then we must draw `sample_shape + [12]` samples, where
+    # `12 == n_batch // underlying_n_batch`.
     batch_shape = self.batch_shape_tensor()
-    batch_rank = ps.rank_from_shape(batch_shape)
     n_batch = ps.reduce_prod(batch_shape)
+    underlying_batch_shape = self.distribution.batch_shape_tensor()
+    underlying_n_batch = ps.reduce_prod(underlying_batch_shape)
+    return ps.concat(
+        [sample_shape,
+         [ps.maximum(0, n_batch // underlying_n_batch)]],
+        axis=0)
+
+  def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
+    if event_shape is None:
+      event_shape = self.event_shape_tensor()
+
+    batch_shape = self.batch_shape_tensor()
+    batch_rank = ps.rank_from_shape(batch_shape)
 
     underlying_batch_shape = self.distribution.batch_shape_tensor()
     underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
-    underlying_n_batch = ps.reduce_prod(underlying_batch_shape)
 
-    # Left pad underlying shape with any necessary ones.
+    # Continuing the example from `_augment_sample_shape`, suppose we have:
+    # - sample shape of `[n]`,
+    # - underlying distribution batch shape of `[2, 1]`,
+    # - final broadcast batch shape of `[4, 2, 3]`,
+    # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
+    # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.
+
+    # First, we reshape to expand out the batch elements:
+    # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
+    # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
+    # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
     underlying_bcast_shp = ps.concat(
         [ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                  dtype=underlying_batch_shape.dtype),
          underlying_batch_shape],
         axis=0)
-
-    # Determine how many underlying samples to produce.
-    n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
-    samps = self.distribution.sample([n, n_bcast_samples], seed=seed)
-
     is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
+    x_with_doubled_batch = tf.reshape(
+        x,
+        ps.concat([sample_shape,
+                   ps.where(is_dim_bcast, batch_shape, 1),
+                   underlying_bcast_shp,
+                   event_shape], axis=0))
+
+    # Next, construct the permutation that interleaves the batch dimensions,
+    # resulting in samples with shape
+    # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
+    # Note that each interleaved pair of batch dimensions contains exactly one
+    # dim of size `1` and one of size `>= 1`.
+    sample_ndims = ps.rank_from_shape(sample_shape)
+    x_with_interleaved_batch = tf.transpose(
+        x_with_doubled_batch,
+        perm=ps.concat([
+            ps.range(sample_ndims),
+            sample_ndims + ps.reshape(
+                ps.stack([ps.range(batch_rank),
+                          ps.range(batch_rank) + batch_rank], axis=-1),
+                [-1]),
+            sample_ndims + 2 * batch_rank + ps.range(
+                ps.rank_from_shape(event_shape))], axis=0))
+
+    # Final reshape to remove the spurious `1` dimensions.
+    return tf.reshape(
+        x_with_interleaved_batch,
+        ps.concat([sample_shape, batch_shape, event_shape], axis=0))
 
-    event_shape = self.event_shape_tensor()
-    event_rank = ps.rank_from_shape(event_shape)
-    shp = ps.concat([[n], ps.where(is_dim_bcast, batch_shape, 1),
-                     underlying_bcast_shp,
-                     event_shape], axis=0)
-    # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
-    samps = tf.reshape(samps, shp)
-    # Interleave broadcast and underlying axis indices for transpose.
-    interleaved_batch_axes = ps.reshape(
-        ps.stack([ps.range(batch_rank),
-                  ps.range(batch_rank) + batch_rank],
-                 axis=-1),
-        [-1]) + 1
-
-    event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
-    perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
-    samps = tf.transpose(samps, perm=perm)
-    # Finally, reshape to the fully-broadcast batch shape.
-    return tf.reshape(samps, ps.concat([[n], batch_shape, event_shape], axis=0))
+  def _sample_n(self, n, seed=None):
+    sample_shape = ps.reshape(n, [1])
+    x = self.distribution.sample(
+        self._augment_sample_shape(sample_shape), seed=seed)
+    return self._transpose_and_reshape_result(x, sample_shape=sample_shape)
+
+  def _sample_and_log_prob(self, sample_shape, seed):
+    x, lp = self.distribution.experimental_sample_and_log_prob(
+        self._augment_sample_shape(sample_shape), seed=seed)
+    return (self._transpose_and_reshape_result(x, sample_shape),
+            self._transpose_and_reshape_result(lp, sample_shape,
+                                               event_shape=()))
 
   _log_prob = _make_bcast_fn('log_prob', n_event_shapes=0)
   _prob = _make_bcast_fn('prob', n_event_shapes=0)

tensorflow_probability/python/distributions/batch_broadcast_test.py

Lines changed: 12 additions & 1 deletion
@@ -76,7 +76,8 @@ def test_sample(self, data):
                    batch_shape,
                    dist.event_shape_tensor()],
                   axis=0)
-    sample = dist.sample(sample_shape, seed=test_util.test_seed())
+    sample = dist.sample(sample_shape,
+                         seed=test_util.test_seed(sampler_type='stateless'))
     if self.is_static_shape:
       self.assertEqual(tf.TensorShape(self.evaluate(sample_batch_event)),
                        sample.shape)
@@ -89,6 +90,16 @@ def test_sample(self, data):
         sample,
         atol=.1)
 
+    # Check that `sample_and_log_prob` also gives a correctly-shaped sample
+    # with correct log-prob.
+    sample2, lp = dist.experimental_sample_and_log_prob(
+        sample_shape, seed=test_util.test_seed(sampler_type='stateless'))
+    if self.is_static_shape:
+      self.assertEqual(tf.TensorShape(self.evaluate(sample_batch_event)),
+                       sample2.shape)
+    self.assertAllEqual(sample_batch_event, tf.shape(sample2))
+    self.assertAllClose(lp, dist.log_prob(sample2))
+
   @hp.given(hps.data())
   @tfp_hps.tfp_hp_settings(default_max_examples=5)
   def test_log_prob(self, data):
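
For context, a usage sketch of the new code path (not from the commit; the constructor argument and the integer seed are illustrative and should be checked against the `BatchBroadcast` docstring). The point of `_sample_and_log_prob` here is that each broadcast batch member gets an independent draw, rather than one underlying draw tiled across the batch:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# A [2, 1]-batch base distribution, broadcast up to batch shape [4, 2, 3].
base = tfd.Normal(loc=tf.zeros([2, 1]), scale=1.)
dist = tfd.BatchBroadcast(base, [4, 2, 3])

# Independent samples for every broadcast batch member, with matching
# log-probs computed in a single pass.
x, lp = dist.experimental_sample_and_log_prob([7], seed=42)
print(x.shape)   # (7, 4, 2, 3)
print(lp.shape)  # (7, 4, 2, 3)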
