Skip to content

Commit 784bcaa

Browse files
derifatives authored and tensorflower-gardener committed
Add posterior_marginals and posterior_mode methods to MixtureSameFamily.
PiperOrigin-RevId: 385042266
1 parent 7062386 commit 784bcaa

File tree

3 files changed

+98
-2
lines changed

3 files changed

+98
-2
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,7 @@ multi_substrate_py_library(
14181418
name = "mixture_same_family",
14191419
srcs = ["mixture_same_family.py"],
14201420
deps = [
1421+
":categorical",
14211422
":distribution",
14221423
":independent",
14231424
# tensorflow dep,

tensorflow_probability/python/distributions/mixture_same_family.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import tensorflow.compat.v2 as tf
2525

26+
from tensorflow_probability.python.distributions import categorical
2627
from tensorflow_probability.python.distributions import distribution
2728
from tensorflow_probability.python.distributions import independent
2829
from tensorflow_probability.python.internal import assert_util
@@ -353,12 +354,30 @@ def _sample_n(self, n, seed):
353354

354355
return ret
355356

356-
def _log_prob(self, x):
357+
def _per_mixture_component_log_prob(self, x):
358+
"""Per mixture component log probability.
359+
360+
Args:
361+
x: A tensor representing observations from the mixture. Must
362+
be broadcastable with the mixture's batch shape.
363+
364+
Returns:
365+
A Tensor representing, for each observation and for each mixture
366+
component, the log joint probability of that mixture component and
367+
the observation. The shape will be equal to the concatenation of (1) the
368+
broadcast shape of the observations and the batch shape, and (2) the
369+
number of mixture components.
370+
"""
357371
x = self._pad_sample_dims(x)
358372
log_prob_x = self.components_distribution.log_prob(x) # [S, B, k]
359373
log_mix_prob = tf.math.log_softmax(
360374
self.mixture_distribution.logits_parameter(), axis=-1) # [B, k]
361-
return tf.reduce_logsumexp(log_prob_x + log_mix_prob, axis=-1) # [S, B]
375+
return log_prob_x + log_mix_prob # [S, B, k]
376+
377+
def _log_prob(self, x, log_joint=None):
378+
if log_joint is None:
379+
log_joint = self._per_mixture_component_log_prob(x)
380+
return tf.reduce_logsumexp(log_joint, axis=-1) # [S, B]
362381

363382
def _mean(self):
364383
probs = self.mixture_distribution.probs_parameter() # [B, k] or [k]
@@ -424,6 +443,48 @@ def _covariance(self):
424443
axis=-3) # [B, E, E]
425444
return mean_cond_var + var_cond_mean # [B, E, E]
426445

446+
def posterior_marginal(self, observations, name='posterior_marginals'):
447+
"""Compute the marginal posterior distribution for a batch of observations.
448+
449+
Note: The behavior of this function is undefined if the `observations`
450+
argument represents impossible observations from the model.
451+
452+
Args:
453+
observations: A tensor representing observations from the mixture. Must
454+
be broadcastable with the mixture's batch shape.
455+
name: A string naming a scope.
456+
457+
Returns:
458+
posterior_marginals: A `Categorical` distribution object representing
459+
the marginal probability of the components of the mixture. The batch
460+
shape of the `Categorical` will be the broadcast shape of `observations`
461+
and the mixture batch shape; the number of classes will equal the
462+
number of mixture components.
463+
"""
464+
with self._name_and_control_scope(name):
465+
return categorical.Categorical(
466+
logits=self._per_mixture_component_log_prob(observations))
467+
468+
def posterior_mode(self, observations, name='posterior_mode'):
469+
"""Compute the posterior mode for a batch of distributions.
470+
471+
Note: The behavior of this function is undefined if the `observations`
472+
argument represents impossible observations from the mixture.
473+
474+
Args:
475+
observations: A tensor representing observations from the mixture. Must
476+
be broadcastable with the mixture's batch shape.
477+
name: A string naming a scope.
478+
479+
Returns:
480+
A Tensor representing the mode (most likely component) for each
481+
observation. The shape will be equal to the broadcast shape of the
482+
observations and the batch shape.
483+
"""
484+
with self._name_and_control_scope(name):
485+
return tf.math.argmax(
486+
self._per_mixture_component_log_prob(observations), axis=-1)
487+
427488
def _pad_sample_dims(self, x, event_ndims=None):
428489
with tf.name_scope('pad_sample_dims'):
429490
if event_ndims is None:

tensorflow_probability/python/distributions/mixture_same_family_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,40 @@ def testVarianceConsistentCovariance(self):
148148
cov_, var_ = self.evaluate([gm.covariance(), gm.variance()])
149149
self.assertAllClose(cov_.diagonal(), var_, atol=0.)
150150

151+
def testPosteriorMarginal(self):
152+
bm = tfd.MixtureSameFamily(
153+
mixture_distribution=tfd.Categorical(
154+
probs=self._build_tensor([0.1, 0.9])),
155+
components_distribution=tfd.Categorical(
156+
probs=self._build_tensor([[.2, .3, .5],
157+
[.7, .2, .1]])),
158+
validate_args=True)
159+
marginal_dist = bm.posterior_marginal(self._build_tensor([0., 1., 2.]))
160+
marginals = self.evaluate(marginal_dist.probs_parameter())
161+
162+
self.assertAllEqual([3, 2], self._shape(marginals))
163+
164+
expected_marginals = [
165+
[(.1*.2)/(.1*.2 + .9*.7), (.9*.7)/(.1*.2 + .9*.7)],
166+
[(.1*.3)/(.1*.3 + .9*.2), (.9*.2)/(.1*.3 + .9*.2)],
167+
[(.1*.5)/(.1*.5 + .9*.1), (.9*.1)/(.1*.5 + .9*.1)]
168+
]
169+
170+
self.assertAllClose(marginals, expected_marginals)
171+
172+
def testPosteriorMode(self):
173+
gm = tfd.MixtureSameFamily(
174+
mixture_distribution=tfd.Categorical(
175+
probs=self._build_tensor([[0.5, 0.5],
176+
[0.01, 0.99]])),
177+
components_distribution=tfd.Normal(
178+
loc=self._build_tensor([[-1., 1.], [-1., 1.]]),
179+
scale=self._build_tensor(1.)))
180+
mode = gm.posterior_mode(
181+
self._build_tensor([[1.], [-1.], [-6.]]))
182+
self.assertAllEqual([3, 2], self._shape(mode))
183+
self.assertAllEqual([[1, 1], [0, 1], [0, 0]], self.evaluate(mode))
184+
151185
def testReparameterizationOfNonReparameterizedComponents(self):
152186
with self.assertRaises(ValueError):
153187
tfd.MixtureSameFamily(

0 commit comments

Comments (0)