
Commit 12d1809

Googler authored and tensorflower-gardener committed
Correctly compute Categorical.entropy for large negative float32 logits.
PiperOrigin-RevId: 381039181
1 parent d3398a7 commit 12d1809

File tree

  tensorflow_probability/python/distributions/categorical.py
  tensorflow_probability/python/distributions/categorical_test.py

2 files changed: +23 -19 lines changed

tensorflow_probability/python/distributions/categorical.py

Lines changed: 11 additions & 19 deletions
@@ -303,31 +303,23 @@ def _entropy(self):
       return -tf.reduce_sum(
           tf.math.multiply_no_nan(tf.math.log(probs), probs),
           axis=-1)
-    # The following result can be derived as follows. Write log(p[i]) as:
-    # s[i]-m-lse(s[i]-m) where m=max(s), then you have:
-    # sum_i exp(s[i]-m-lse(s-m)) (s[i] - m - lse(s-m))
-    # = -m - lse(s-m) + sum_i s[i] exp(s[i]-m-lse(s-m))
-    # = -m - lse(s-m) + (1/exp(lse(s-m))) sum_i s[i] exp(s[i]-m)
-    # = -m - lse(s-m) + (1/sumexp(s-m)) sum_i s[i] exp(s[i]-m)
-    # Write x[i]=s[i]-m then you have:
-    # = -m - lse(x) + (1/sum_exp(x)) sum_i s[i] exp(x[i])
-    # Negating all of this result is the Shanon (discrete) entropy.
+    # The following result can be derived as follows. Let s[i] be a logit.
+    # The entropy is:
+    #   H = -sum_i(p[i] * log(p[i]))
+    #     = -sum_i(p[i] * (s[i] - logsumexp(s))
+    #     = logsumexp(s) - sum_i(p[i] * s[i])
     logits = tf.convert_to_tensor(self._logits)
-    m = tf.reduce_max(logits, axis=-1, keepdims=True)
-    x = logits - m
-    sum_exp_x = tf.reduce_sum(tf.math.exp(x), axis=-1)
-    lse_logits = m[..., 0] + tf.math.log(sum_exp_x)
+    logits = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
+    lse_logits = tf.reduce_logsumexp(logits, axis=-1)
+
     # TODO(b/161014180): Workaround to support correct gradient calculations
     # with -inf logits.
-    is_inf_logits = tf.cast(tf.math.is_inf(logits), dtype=tf.float32)
-    is_negative_logits = tf.cast(logits < 0, dtype=tf.float32)
     masked_logits = tf.where(
-        tf.cast((is_inf_logits * is_negative_logits), dtype=bool),
+        (tf.math.is_inf(logits) & (logits < 0)),
         tf.cast(1.0, dtype=logits.dtype), logits)
-
     return lse_logits - tf.reduce_sum(
-        tf.math.multiply_no_nan(masked_logits, tf.math.exp(x)),
-        axis=-1) / sum_exp_x
+        tf.math.multiply_no_nan(masked_logits, tf.math.exp(logits)),
+        axis=-1) / tf.math.exp(lse_logits)

   def _mode(self):
     x = self._probs if self._logits is None else self._logits
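
Why this matters numerically (an illustrative sketch, not part of the commit): the old code reconstructed lse_logits as m + log(sum(exp(logits - m))). For logits on the order of -1e9, the log(sum(exp(...))) term is well below half a float32 ulp of m, so the addition rounds it away and the entropy-carrying information is lost. The new code never adds m back: it works entirely with max-shifted logits, which is safe because the entropy is unchanged when every logit is shifted by the same constant. A minimal NumPy reproduction of the absorption, assuming 11 equally likely categories:

    import numpy as np

    # Old construction: add the per-row max back before using lse_logits.
    m = np.float32(-1e9)                    # max of the logits
    log_sum_exp = np.float32(np.log(11.0))  # log(sum(exp(logits - m))) = log(11)
    print(m + log_sum_exp == m)             # True: log(11) is absorbed in float32

    # New construction: stay in the max-shifted frame (all shifted logits are 0).
    shifted = np.zeros(11, dtype=np.float32)          # logits - max(logits)
    lse = np.float32(np.log(np.exp(shifted).sum()))   # reduce_logsumexp analogue
    entropy = lse - (shifted * np.exp(shifted)).sum() / np.exp(lse)
    print(entropy, np.log(11.0))                      # both ~2.3979

The tf.where / tf.math.multiply_no_nan masking that survives the rewrite serves a separate purpose, per the TODO above it: it is a workaround so gradients are computed correctly when some logits are exactly -inf.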

tensorflow_probability/python/distributions/categorical_test.py

Lines changed: 12 additions & 0 deletions
@@ -393,6 +393,18 @@ def testEntropyWithNegInfLogits(self):
     ans = [-(0.5*np.log(0.5) + 0.5*np.log(0.5)), -(np.log(1))]
     self.assertAllClose(self.evaluate(dist_entropy), ans)

+  def testEntropyWithLargeNegLogits(self):
+    num_categories = 11
+    logits = np.array([
+        [-1e7] * num_categories,
+        [-1e8] * num_categories,
+        [-1e9] * num_categories], dtype=np.float32)
+    dist = tfd.Categorical(logits=logits, validate_args=True)
+    dist_entropy = dist.entropy()
+
+    ans = [np.log(num_categories)] * 3
+    self.assertAllClose(self.evaluate(dist_entropy), ans)
+
   def testSample(self):
     histograms = np.array([[[0.2, 0.8], [0.4, 0.6]]])
     dist = tfd.Categorical(tf.math.log(histograms) - 50., validate_args=True)
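
The new test pins down the expected behavior: identical logits define a uniform distribution, so the entropy must be log(num_categories) ≈ 2.398 for 11 categories no matter how negative the shared logit (-1e7, -1e8, -1e9) is; before this change the large float32 offset swamped the calculation. A usage sketch of the same check outside the test harness, assuming a TensorFlow Probability build that includes this commit:

    import numpy as np
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    logits = np.full([3, 11], -1e9, dtype=np.float32)   # 3 batches, uniform over 11
    dist = tfd.Categorical(logits=logits, validate_args=True)
    print(dist.entropy())   # approximately [2.3979, 2.3979, 2.3979] == log(11)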
