Skip to content

Commit 50c549a

Browse files
Googler authored and tensorflower-gardener committed
When inflating a continuous distribution, change the log_prob implementation to one with better maximum likelihood estimation properties.

PiperOrigin-RevId: 473308814
1 parent 5f1c6b8 commit 50c549a

File tree

2 files changed

33 additions and 3 deletions

2 files changed

33 additions and 3 deletions

tensorflow_probability/python/distributions/inflated.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ def __init__(self,
102102
lambda logit: tf.stack([logit, -logit], axis=-1),
103103
dtype=self._inflated_loc_logits.dtype,
104104
shape=self._inflated_loc_logits.shape + (2,))
105-
categorical_dist = categorical.Categorical(
105+
self._categorical_dist = categorical.Categorical(
106106
logits=cat_logits,
107107
validate_args=validate_args,
108108
allow_nan_stats=allow_nan_stats)
@@ -116,14 +116,14 @@ def __init__(self,
116116
dtype=self._inflated_loc_probs.dtype,
117117
shape=self._inflated_loc_probs.shape + (2,)
118118
)
119-
categorical_dist = categorical.Categorical(
119+
self._categorical_dist = categorical.Categorical(
120120
probs=cat_probs,
121121
validate_args=validate_args,
122122
allow_nan_stats=allow_nan_stats)
123123
probs_or_logits = self._inflated_loc_probs
124124

125125
super(_Inflated, self).__init__(
126-
cat=categorical_dist,
126+
cat=self._categorical_dist,
127127
components=[
128128
deterministic.Deterministic(
129129
DeferredTensor(
@@ -151,6 +151,24 @@ def _parameter_properties(cls, dtype, num_classes=None):
151151
),
152152
inflated_loc=parameter_properties.ParameterProperties())
153153

154+
def _log_prob(self, x):
155+
# We override the log_prob implementation from Mixture in the case
156+
# where we are inflating a continuous distribution, because we have
157+
# found that this "censored" version gives a good maximum likelihood
158+
# estimate of the continuous distribution's parameters but the
159+
# default implementation doesn't. This follows the proposal in
160+
# https://arxiv.org/pdf/2010.09647.pdf for summing distributions of
161+
# different Hausdorff dimension.
162+
if isinstance(self._distribution,
163+
distribution_lib.DiscreteDistributionMixin):
164+
return super(_Inflated, self)._log_prob(x)
165+
else:
166+
return tf.where(
167+
tf.equal(x, self._inflated_loc),
168+
self._categorical_dist.log_prob(0),
169+
self._categorical_dist.log_prob(1) +
170+
self._distribution.log_prob(x))
171+
154172
@property
155173
def distribution(self):
156174
"""The distribution used for the non-inflated part."""

tensorflow_probability/python/distributions/inflated_test.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,18 @@ def test_inflated_batched(self):
5959
samples = zinb.sample(seed=test_util.test_seed())
6060
self.assertEqual((5,), samples.shape)
6161

62+
def test_inflated_continuous_log_prob(self):
63+
spike_and_slab = inflated.Inflated(
64+
normal.Normal(loc=1.0, scale=2.0), inflated_loc_probs=0.1)
65+
self.assertEqual(self.evaluate(tf.math.log(0.1)),
66+
self.evaluate(spike_and_slab.log_prob(0.0)))
67+
self.assertNear(
68+
self.evaluate(tf.math.log(0.9) + normal.Normal(
69+
loc=1.0, scale=2.0).log_prob(2.0)),
70+
self.evaluate(spike_and_slab.log_prob(2.0)),
71+
1e-6
72+
)
73+
6274
def test_inflated_factory(self):
6375
spike_and_slab_class = inflated.inflated_factory('SpikeAndSlab',
6476
normal.Normal, 0.0)

0 commit comments

Comments
 (0)