Skip to content

Commit 4bc1db8

Browse files
Make the .log_prob of an inflated continuous distribution differentiable
in the case where the continuous distribution isn't differentiable at the inflated location (for example, a zero-inflated LogNormal). Also, add inflated_loc_atol and inflated_loc_rtol parameters to Inflated. PiperOrigin-RevId: 475386917
1 parent 504f5a4 commit 4bc1db8

File tree

3 files changed

+87
-10
lines changed

3 files changed

+87
-10
lines changed

tensorflow_probability/python/distributions/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,7 @@ multi_substrate_py_library(
10141014
"//tensorflow_probability/python/internal:auto_composite_tensor",
10151015
"//tensorflow_probability/python/internal:dtype_util",
10161016
"//tensorflow_probability/python/internal:parameter_properties",
1017+
"//tensorflow_probability/python/internal:samplers",
10171018
"//tensorflow_probability/python/internal:tensor_util",
10181019
"//tensorflow_probability/python/util:deferred_tensor",
10191020
],
@@ -3255,13 +3256,16 @@ multi_substrate_py_test(
32553256
name = "inflated_test",
32563257
srcs = ["inflated_test.py"],
32573258
deps = [
3259+
":gamma",
32583260
":inflated",
3261+
":lognormal",
32593262
":negative_binomial",
32603263
":normal",
32613264
# numpy dep,
32623265
# tensorflow dep,
32633266
"//tensorflow_probability/python/experimental/util",
32643267
"//tensorflow_probability/python/internal:test_util",
3268+
"//tensorflow_probability/python/math:gradient",
32653269
],
32663270
)
32673271

tensorflow_probability/python/distributions/inflated.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,18 @@
2828
from tensorflow_probability.python.internal import auto_composite_tensor
2929
from tensorflow_probability.python.internal import dtype_util
3030
from tensorflow_probability.python.internal import parameter_properties
31+
from tensorflow_probability.python.internal import samplers
3132
from tensorflow_probability.python.internal import tensor_util
3233
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
3334

3435
__all__ = ['Inflated', 'inflated_factory', 'ZeroInflatedNegativeBinomial']
3536

3637

38+
def _safe_value_for_distribution(dist):
39+
"""Returns an x for which it is safe to differentiate dist.log_prob(x)."""
40+
return dist.sample(seed=samplers.zeros_seed())
41+
42+
3743
class _Inflated(mixture.Mixture):
3844
"""A mixture of a point-mass and another distribution.
3945
@@ -53,6 +59,8 @@ def __init__(self,
5359
inflated_loc_logits=None,
5460
inflated_loc_probs=None,
5561
inflated_loc=0.0,
62+
inflated_loc_atol=None,
63+
inflated_loc_rtol=None,
5664
validate_args=False,
5765
allow_nan_stats=True,
5866
name='Inflated'):
@@ -71,6 +79,12 @@ def __init__(self,
7179
`inflated_loc_logits` should be passed in.
7280
inflated_loc: A scalar or tensor containing the locations of the point
7381
mass component of the mixture.
82+
inflated_loc_atol: Non-negative `Tensor` of same `dtype` as
83+
`inflated_loc` and broadcastable shape. The absolute tolerance for
84+
comparing closeness to `inflated_loc`. Default is `0`.
85+
inflated_loc_rtol: Non-negative `Tensor` of same `dtype` as
86+
`inflated_loc` and broadcastable shape. The relative tolerance for
87+
comparing closeness to `inflated_loc`. Default is `0`.
7488
validate_args: If true, inconsistent batch or event sizes raise a runtime
7589
error.
7690
allow_nan_stats: If false, any undefined statistics for any batch member
@@ -95,6 +109,12 @@ def __init__(self,
95109
inflated_loc_probs, dtype=dtype, name='inflated_loc_probs')
96110
self._inflated_loc = tensor_util.convert_nonref_to_tensor(
97111
inflated_loc, dtype=dtype, name='inflated_loc')
112+
self._inflated_loc_atol = tensor_util.convert_nonref_to_tensor(
113+
0 if inflated_loc_atol is None else inflated_loc_atol,
114+
dtype=dtype, name='inflated_loc_atol')
115+
self._inflated_loc_rtol = tensor_util.convert_nonref_to_tensor(
116+
0 if inflated_loc_rtol is None else inflated_loc_rtol,
117+
dtype=dtype, name='inflated_loc_rtol')
98118

99119
if inflated_loc_probs is None:
100120
cat_logits = DeferredTensor(
@@ -122,17 +142,23 @@ def __init__(self,
122142
allow_nan_stats=allow_nan_stats)
123143
probs_or_logits = self._inflated_loc_probs
124144

145+
self._deterministic = deterministic.Deterministic(
146+
DeferredTensor(
147+
probs_or_logits,
148+
# pylint: disable=g-long-lambda
149+
lambda _: tf.broadcast_to(self._inflated_loc,
150+
tf.shape(probs_or_logits))
151+
# pylint: enable=g-long-lambda
152+
),
153+
atol=self._inflated_loc_atol,
154+
rtol=self._inflated_loc_rtol,
155+
validate_args=validate_args,
156+
allow_nan_stats=allow_nan_stats)
157+
125158
super(_Inflated, self).__init__(
126159
cat=self._categorical_dist,
127160
components=[
128-
deterministic.Deterministic(
129-
DeferredTensor(
130-
probs_or_logits,
131-
lambda x: tf.constant( # pylint: disable=g-long-lambda
132-
inflated_loc, dtype=distribution.dtype,
133-
shape=probs_or_logits.shape)),
134-
validate_args=validate_args,
135-
allow_nan_stats=allow_nan_stats),
161+
self._deterministic,
136162
distribution
137163
],
138164
validate_args=validate_args,
@@ -151,6 +177,12 @@ def _parameter_properties(cls, dtype, num_classes=None):
151177
),
152178
inflated_loc=parameter_properties.ParameterProperties())
153179

180+
def _almost_inflated_loc(self, x):
181+
# pylint: disable=protected-access
182+
return tf.abs(x - self._inflated_loc) <= self._deterministic._slack(
183+
self._inflated_loc)
184+
# pylint: enable=protected-access
185+
154186
def _log_prob(self, x):
155187
# We override the log_prob implementation from Mixture in the case
156188
# where we are inflating a continuous distribution, because we have
@@ -163,11 +195,19 @@ def _log_prob(self, x):
163195
distribution_lib.DiscreteDistributionMixin):
164196
return super(_Inflated, self)._log_prob(x)
165197
else:
198+
# Enable non-NaN gradients of the log_prob, even if the gradient of
199+
# the continuous distribution is NaN at _inflated_loc. See
200+
# https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
201+
# for details.
202+
safe_x = tf.where(
203+
self._almost_inflated_loc(x),
204+
_safe_value_for_distribution(self._distribution),
205+
x)
166206
return tf.where(
167-
tf.equal(x, self._inflated_loc),
207+
self._almost_inflated_loc(x),
168208
self._categorical_dist.log_prob(0),
169209
self._categorical_dist.log_prob(1) +
170-
self._distribution.log_prob(x))
210+
self._distribution.log_prob(safe_x))
171211

172212
@property
173213
def distribution(self):

tensorflow_probability/python/distributions/inflated_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
import numpy as np
1717
import tensorflow.compat.v2 as tf
1818

19+
from tensorflow_probability.python.distributions import gamma
1920
from tensorflow_probability.python.distributions import inflated
21+
from tensorflow_probability.python.distributions import lognormal
2022
from tensorflow_probability.python.distributions import negative_binomial
2123
from tensorflow_probability.python.distributions import normal
2224
from tensorflow_probability.python.experimental import util
2325
from tensorflow_probability.python.experimental.util import trainable
2426
from tensorflow_probability.python.internal import test_util
27+
from tensorflow_probability.python.math import gradient
2528

2629

2730
class DistributionsTest(test_util.TestCase):
@@ -114,6 +117,36 @@ def test_zinb_as_composite_tensor(self):
114117
comp_zinb = util.as_composite(zinb)
115118
unused_as_tensors = tf.nest.flatten(comp_zinb)
116119

120+
@test_util.disable_test_for_backend(
121+
disable_numpy=True,
122+
reason='Only TF has gradient tape')
123+
def test_safe_value_for_distribution(self):
124+
x = self.evaluate(inflated._safe_value_for_distribution(
125+
gamma.Gamma(concentration=3.0, rate=2.0)))
126+
lp, grad = gradient.value_and_gradient(
127+
lambda p: gamma.Gamma(concentration=p, rate=2.0).log_prob(x),
128+
3.0)
129+
self.assertAllFinite(lp)
130+
self.assertAllFinite(grad)
131+
132+
@test_util.disable_test_for_backend(
133+
disable_numpy=True,
134+
reason='Only TF has gradient tape')
135+
def test_log_prob_for_inflated_lognormal_is_diffable(self):
136+
x = tf.constant([0.0, 1.0])
137+
138+
# pylint: disable=g-long-lambda
139+
lp, grad = gradient.value_and_gradient(
140+
lambda loc: inflated.Inflated(
141+
lognormal.LogNormal(loc=loc, scale=1.0),
142+
inflated_loc_probs=0.5,
143+
).log_prob(x),
144+
5.0,
145+
)
146+
# pylint: enable=g-long-lambda
147+
self.assertAllFinite(lp)
148+
self.assertAllFinite(grad)
149+
117150

118151
if __name__ == '__main__':
119152
test_util.main()

0 commit comments

Comments
 (0)