Skip to content

Commit 1f2b8e0

Browse files
davmretensorflower-gardener
authored andcommitted
Implement sample_and_log_prob for several meta distributions.
PiperOrigin-RevId: 375493200
1 parent 39d9e61 commit 1f2b8e0

File tree

5 files changed

+33
-0
lines changed

5 files changed

+33
-0
lines changed

tensorflow_probability/python/distributions/blockwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,11 @@ def _sample_n(self, n, seed=None):
333333
return self._flatten_and_concat_event(
334334
self._distribution.sample(n, seed=seed))
335335

336+
def _sample_and_log_prob(self, sample_shape, seed):
337+
x, lp = self._distribution.experimental_sample_and_log_prob(
338+
sample_shape, seed=seed)
339+
return self._flatten_and_concat_event(x), lp
340+
336341
def _log_prob(self, x):
337342
return self._distribution.log_prob(self._split_and_reshape_event(x))
338343

tensorflow_probability/python/distributions/gaussian_process.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,15 @@ def _event_shape(self, index_points=None):
532532
def _sample_n(self, n, seed=None, index_points=None):
533533
return self.get_marginal_distribution(index_points).sample(n, seed=seed)
534534

535+
def _sample_and_log_prob(self,
536+
sample_shape,
537+
seed,
538+
index_points=None,
539+
**kwargs):
540+
return self.get_marginal_distribution(
541+
index_points).experimental_sample_and_log_prob(
542+
sample_shape, seed=seed, **kwargs)
543+
535544
def _log_survival_function(self, value, index_points=None):
536545
return self.get_marginal_distribution(
537546
index_points).log_survival_function(value)

tensorflow_probability/python/distributions/independent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,11 @@ def _sum_fn(self):
280280
return lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total
281281
return tf.math.reduce_sum
282282

283+
def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
284+
x, lp = self.distribution.experimental_sample_and_log_prob(
285+
sample_shape, seed=seed, **kwargs)
286+
return x, self._reduce(self._sum_fn(), lp)
287+
283288
def _log_prob(self, x, **kwargs):
284289
return self._reduce(
285290
self._sum_fn(), self.distribution.log_prob(x, **kwargs))

tensorflow_probability/python/distributions/matrix_normal_linear_operator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,13 @@ def _sample_n(self, n, seed=None):
213213
samples = self._as_multivariate_normal(loc=loc).sample(n, seed=seed)
214214
return _unvec(samples, self._event_shape_tensor(loc=loc))
215215

216+
def _sample_and_log_prob(self, sample_shape, seed):
217+
loc = tf.convert_to_tensor(self.loc)
218+
x, lp = self._as_multivariate_normal(
219+
loc=loc).experimental_sample_and_log_prob(
220+
sample_shape, seed=seed)
221+
return _unvec(x, self._event_shape_tensor(loc=loc)), lp
222+
216223
def _entropy(self):
217224
return self._as_multivariate_normal().entropy()
218225

tensorflow_probability/python/distributions/matrix_t_linear_operator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,13 @@ def _sample_n(self, n, seed=None):
242242
samples = self._as_multivariate_t(loc=loc).sample(n, seed=seed)
243243
return _unvec(samples, self._event_shape_tensor(loc=loc))
244244

245+
def _sample_and_log_prob(self, sample_shape, seed):
246+
loc = tf.convert_to_tensor(self.loc)
247+
x, lp = self._as_multivariate_t(
248+
loc=loc).experimental_sample_and_log_prob(
249+
sample_shape, seed=seed)
250+
return _unvec(x, self._event_shape_tensor(loc=loc)), lp
251+
245252
def _entropy(self):
246253
return self._as_multivariate_t().entropy()
247254

0 commit comments

Comments
 (0)