Skip to content

Commit 82ec0fe

Browse files
Fabian Mentzercopybara-github
authored andcommitted
Fix instability in mixture distributions if log_prob becomes -inf.
PiperOrigin-RevId: 516755689 Change-Id: I866976f926256fcef583313f734c055e60ae452d
1 parent e2879ec commit 82ec0fe

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

tensorflow_compression/python/distributions/uniform_noise.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ def _logsum_expbig_minus_expsmall(big, small):
4444
`tf.Tensor` containing the result.
4545
"""
4646
with tf.name_scope("logsum_expbig_minus_expsmall"):
47-
return tf.math.log1p(-tf.exp(small - big)) + big
47+
# Have to special case `inf` and `-inf` since otherwise we get a NaN
48+
# out of the exp (if both small and big are -inf).
49+
return tf.where(
50+
tf.math.is_inf(big), big, tf.math.log1p(-tf.exp(small - big)) + big
51+
)
4852

4953

5054
class UniformNoiseAdapter(tfp.distributions.Distribution):

tensorflow_compression/python/distributions/uniform_noise_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,15 @@ def test_stats_throw_error(self):
161161
with self.assertRaises(NotImplementedError):
162162
dist.survival_function(.5)
163163

164+
def test_stable(self):
165+
# An extreme distribution that has probability 1 at 0 and probability 0
166+
# otherwise.
167+
dist = self.dist_cls(loc=[0, 0], scale=[0, 0], weight=[0.5, 0.5])
168+
with self.subTest("AtTheMode"):
169+
self.assertAllClose(dist.prob([0]), [1.0])
170+
with self.subTest("NotAtTheMode"):
171+
self.assertAllClose(dist.prob([1]), [0.0])
172+
164173

165174
class NoisyNormalMixtureTest(MixtureTest, tf.test.TestCase):
166175

0 commit comments

Comments
 (0)