Skip to content

Commit e1d930f

Browse files
davmretensorflower-gardener
authored andcommitted
Implement sample_and_log_prob for BatchConcat and BatchReshape.
PiperOrigin-RevId: 375612805
1 parent 7555d90 commit e1d930f

File tree

4 files changed

+41
-22
lines changed

4 files changed

+41
-22
lines changed

tensorflow_probability/python/distributions/batch_concat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,20 @@ def _sample_n(self, n, seed=None, **kwargs):
423423

424424
return tf.concat(samples, axis=self._axis+1)
425425

426+
def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
427+
all_seeds = samplers.split_seed(
428+
seed, len(self._distributions), salt='BatchConcat')
429+
samples = []
430+
log_probs = []
431+
for d, s in zip(self._distributions, all_seeds):
432+
x, lp = d.experimental_sample_and_log_prob(sample_shape, s)
433+
samples.append(self._broadcast(x, sample_shape))
434+
log_probs.append(self._broadcast(lp, sample_shape))
435+
436+
sample_shape_size = ps.rank_from_shape(sample_shape)
437+
return (tf.concat(samples, axis=self._axis + sample_shape_size),
438+
tf.concat(log_probs, axis=self._axis + sample_shape_size))
439+
426440
def _call_split_concat(self, fn, x, **kwargs):
427441
sample_shape_size, split_x = self._split_sample(x)
428442
result = [

tensorflow_probability/python/distributions/batch_concat_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ def test_sample(self):
128128
samples = concat_dist.sample([12, 20], seed=seed)
129129
self.assertAllEqual(self.evaluate(tf.shape(samples)), [12, 20, 2, 6, 4, 2])
130130

131+
def test_sample_and_log_prob(self):
132+
concat_dist = self.get_distributions()
133+
seed = test_util.test_seed()
134+
samples, lp = concat_dist.experimental_sample_and_log_prob(seed=seed)
135+
self.assertAllEqual(self.evaluate(tf.shape(samples)), [2, 6, 4, 2])
136+
self.assertAllClose(lp, concat_dist.log_prob(samples))
137+
131138
def test_split_sample(self):
132139
concat_dist = self.get_distributions()
133140
x_sample = tf.ones([2, 6, 4, 2])

tensorflow_probability/python/distributions/batch_reshape.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,16 @@ def _sample_n(self, n, seed=None, **kwargs):
211211
axis=0)
212212
return tf.reshape(x, new_shape)
213213

214+
def _sample_and_log_prob(self, sample_shape, seed=None, **kwargs):
215+
x, lp = self.distribution.experimental_sample_and_log_prob(
216+
sample_shape=sample_shape, seed=seed, **kwargs)
217+
return (tf.reshape(x, tf.concat([sample_shape,
218+
self._batch_shape_unexpanded,
219+
self.event_shape_tensor()], axis=0)),
220+
tf.reshape(lp, tf.concat([sample_shape,
221+
self._batch_shape_unexpanded],
222+
axis=0)))
223+
214224
def _log_prob(self, x, **kwargs):
215225
return self._call_reshape_input_output(
216226
self.distribution.log_prob, x, extra_kwargs=kwargs)

tensorflow_probability/python/distributions/batch_reshape_test.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def test_scalar_variate_sample_and_log_prob(self):
200200
# document that the test is not intended to run in eager mode.
201201
return
202202

203-
seed = test_util.test_seed()
203+
seed = test_util.test_seed(sampler_type='stateless')
204204

205205
new_batch_shape = [2, 2]
206206
old_batch_shape = [4]
@@ -220,27 +220,15 @@ def test_scalar_variate_sample_and_log_prob(self):
220220
expected_log_prob = tf.reshape(normal.log_prob(x), expected_log_prob_shape)
221221
actual_log_prob = reshape_normal.log_prob(expected_sample)
222222

223-
[
224-
batch_shape_,
225-
event_shape_,
226-
expected_sample_,
227-
actual_sample_,
228-
expected_log_prob_,
229-
actual_log_prob_,
230-
] = self.evaluate([
231-
batch_shape,
232-
event_shape,
233-
expected_sample,
234-
actual_sample,
235-
expected_log_prob,
236-
actual_log_prob,
237-
])
238-
self.assertAllEqual(new_batch_shape, batch_shape_)
239-
self.assertAllEqual([], event_shape_)
240-
self.assertAllClose(expected_sample_, actual_sample_,
241-
atol=0., rtol=1e-6)
242-
self.assertAllClose(expected_log_prob_, actual_log_prob_,
243-
atol=0., rtol=1e-6)
223+
self.assertAllEqual(new_batch_shape, batch_shape)
224+
self.assertAllEqual([], event_shape)
225+
self.assertAllClose(expected_sample, actual_sample, atol=0., rtol=1e-6)
226+
self.assertAllClose(expected_log_prob, actual_log_prob, atol=0., rtol=1e-6)
227+
228+
slp_sample, slp_lp = reshape_normal.experimental_sample_and_log_prob(
229+
seed=seed)
230+
self.assertAllClose(expected_sample, slp_sample, atol=0., rtol=1e-6)
231+
self.assertAllClose(expected_log_prob, slp_lp, atol=0., rtol=1e-6)
244232
if not self.is_static_shape:
245233
return
246234
self.assertAllEqual(new_batch_shape, reshape_normal.batch_shape)

0 commit comments

Comments
 (0)