@@ -23,6 +23,7 @@
 
 import tensorflow.compat.v2 as tf
 
+from tensorflow_probability.python.distributions import categorical
 from tensorflow_probability.python.distributions import distribution
 from tensorflow_probability.python.distributions import independent
 from tensorflow_probability.python.internal import assert_util

@@ -353,12 +354,30 @@ def _sample_n(self, n, seed):
 
     return ret
 
-  def _log_prob(self, x):
+  def _per_mixture_component_log_prob(self, x):
+    """Per mixture component log probability.
+
+    Args:
+      x: A tensor representing observations from the mixture. Must
+        be broadcastable with the mixture's batch shape.
+
+    Returns:
+      A Tensor representing, for each observation and for each mixture
+      component, the log joint probability of that mixture component and
+      the observation. The shape will be equal to the concatenation of (1) the
+      broadcast shape of the observations and the batch shape, and (2) the
+      number of mixture components.
+    """
     x = self._pad_sample_dims(x)
     log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
     log_mix_prob = tf.math.log_softmax(
         self.mixture_distribution.logits_parameter(), axis=-1)  # [B, k]
-    return tf.reduce_logsumexp(log_prob_x + log_mix_prob, axis=-1)  # [S, B]
+    return log_prob_x + log_mix_prob  # [S, B, k]
+
+  def _log_prob(self, x, log_joint=None):
+    if log_joint is None:
+      log_joint = self._per_mixture_component_log_prob(x)
+    return tf.reduce_logsumexp(log_joint, axis=-1)  # [S, B]
 
   def _mean(self):
     probs = self.mixture_distribution.probs_parameter()  # [B, k] or [k]

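The refactoring above separates the two steps of the mixture log probability: `_per_mixture_component_log_prob` returns the per-component joint terms log p(x | k) + log p(k), and `_log_prob` marginalizes them with a `logsumexp` over the component axis. Below is a minimal sketch of that identity using only the public TFP API; the two-component `gm` is an illustrative stand-in, not part of the patch.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# An illustrative two-component mixture of scalar normals.
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.5, 0.5]))

x = tf.constant([0.25])  # one observation, so S = 1
# Per-component joint log probability, log p(x | k) + log p(k), computed
# with public ops; this mirrors what the new private method returns.
log_joint = (
    gm.components_distribution.log_prob(x[..., tf.newaxis])    # [S, k]
    + tf.math.log_softmax(
        gm.mixture_distribution.logits_parameter(), axis=-1))  # [k]
# Marginalizing the component axis with logsumexp recovers log p(x).
tf.debugging.assert_near(
    tf.reduce_logsumexp(log_joint, axis=-1), gm.log_prob(x))
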
@@ -424,6 +443,48 @@ def _covariance(self):
 
         axis=-3)  # [B, E, E]
     return mean_cond_var + var_cond_mean  # [B, E, E]
 
+  def posterior_marginal(self, observations, name='posterior_marginals'):
+    """Compute the marginal posterior distribution for a batch of observations.
+
+    Note: The behavior of this function is undefined if the `observations`
+    argument represents impossible observations from the model.
+
+    Args:
+      observations: A tensor representing observations from the mixture. Must
+        be broadcastable with the mixture's batch shape.
+      name: A string naming a scope.
+
+    Returns:
+      posterior_marginals: A `Categorical` distribution object representing
+        the marginal probability of the components of the mixture. The batch
+        shape of the `Categorical` will be the broadcast shape of
+        `observations` and the mixture batch shape; the number of classes
+        will equal the number of mixture components.
+    """
+    with self._name_and_control_scope(name):
+      return categorical.Categorical(
+          logits=self._per_mixture_component_log_prob(observations))
+
+  def posterior_mode(self, observations, name='posterior_mode'):
+    """Compute the posterior mode for a batch of observations.
+
+    Note: The behavior of this function is undefined if the `observations`
+    argument represents impossible observations from the mixture.
+
+    Args:
+      observations: A tensor representing observations from the mixture. Must
+        be broadcastable with the mixture's batch shape.
+      name: A string naming a scope.
+
+    Returns:
+      A Tensor representing the mode (most likely component) for each
+      observation. The shape will be equal to the broadcast shape of the
+      observations and the batch shape.
+    """
+    with self._name_and_control_scope(name):
+      return tf.math.argmax(
+          self._per_mixture_component_log_prob(observations), axis=-1)
+
   def _pad_sample_dims(self, x, event_ndims=None):
     with tf.name_scope('pad_sample_dims'):
       if event_ndims is None:

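`posterior_marginal` simply wraps the (unnormalized) per-component joint log probabilities in a `Categorical`; because `Categorical` normalizes its logits, the resulting class probabilities are exactly the posterior responsibilities p(component | x), so no explicit normalization is needed. A hypothetical usage sketch, assuming a TFP build that includes this change and reusing the illustrative `gm` from above:

# Posterior over components for three observations: batch shape [3],
# number of classes 2.
posterior = gm.posterior_marginal(tf.constant([-2.0, 0.25, 2.0]))
responsibilities = posterior.probs_parameter()  # [3, 2]; rows sum to 1.
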
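`posterior_mode` is then just an `argmax` over the same per-component joints, picking the most probable component for each observation. Continuing the hypothetical sketch with the same `gm`:

# Index of the most likely component per observation. With the weights
# and locations of the illustrative `gm` above this yields [0, 1, 1].
mode = gm.posterior_mode(tf.constant([-2.0, 0.25, 2.0]))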