@@ -303,31 +303,23 @@ def _entropy(self):
       return -tf.reduce_sum(
           tf.math.multiply_no_nan(tf.math.log(probs), probs),
           axis=-1)
-    # The following result can be derived as follows. Write log(p[i]) as:
-    # s[i]-m-lse(s[i]-m) where m=max(s), then you have:
-    # sum_i exp(s[i]-m-lse(s-m)) (s[i] - m - lse(s-m))
-    # = -m - lse(s-m) + sum_i s[i] exp(s[i]-m-lse(s-m))
-    # = -m - lse(s-m) + (1/exp(lse(s-m))) sum_i s[i] exp(s[i]-m)
-    # = -m - lse(s-m) + (1/sumexp(s-m)) sum_i s[i] exp(s[i]-m)
-    # Write x[i]=s[i]-m then you have:
-    # = -m - lse(x) + (1/sum_exp(x)) sum_i s[i] exp(x[i])
-    # Negating all of this result is the Shanon (discrete) entropy.
+    # The following result can be derived as follows. Let s[i] be a logit.
+    # The entropy is:
+    #   H = -sum_i(p[i] * log(p[i]))
+    #     = -sum_i(p[i] * (s[i] - logsumexp(s)))
+    #     = logsumexp(s) - sum_i(p[i] * s[i])
     logits = tf.convert_to_tensor(self._logits)
-    m = tf.reduce_max(logits, axis=-1, keepdims=True)
-    x = logits - m
-    sum_exp_x = tf.reduce_sum(tf.math.exp(x), axis=-1)
-    lse_logits = m[..., 0] + tf.math.log(sum_exp_x)
+    logits = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
+    lse_logits = tf.reduce_logsumexp(logits, axis=-1)
+
     # TODO(b/161014180): Workaround to support correct gradient calculations
     # with -inf logits.
-    is_inf_logits = tf.cast(tf.math.is_inf(logits), dtype=tf.float32)
-    is_negative_logits = tf.cast(logits < 0, dtype=tf.float32)
     masked_logits = tf.where(
-        tf.cast((is_inf_logits * is_negative_logits), dtype=bool),
+        (tf.math.is_inf(logits) & (logits < 0)),
         tf.cast(1.0, dtype=logits.dtype), logits)
-
     return lse_logits - tf.reduce_sum(
-        tf.math.multiply_no_nan(masked_logits, tf.math.exp(x)),
-        axis=-1) / sum_exp_x
+        tf.math.multiply_no_nan(masked_logits, tf.math.exp(logits)),
+        axis=-1) / tf.math.exp(lse_logits)
 
   def _mode(self):
     x = self._probs if self._logits is None else self._logits
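For reference, a small standalone check of the identity the new comment relies on, H = -sum_i p[i]*log(p[i]) = logsumexp(s) - sum_i p[i]*s[i], including a -inf logit to exercise the masking workaround. This is only an illustrative sketch of the patched computation, not part of the change itself; the example logits are arbitrary.

import tensorflow as tf

# Arbitrary example logits; the -inf entry exercises the masking path.
logits = tf.constant([2.0, 0.5, -1.0, float('-inf')])

# Naive entropy from normalized probabilities. multiply_no_nan turns the
# 0 * log(0) term coming from the -inf logit into 0 instead of NaN.
probs = tf.nn.softmax(logits)
naive = -tf.reduce_sum(tf.math.multiply_no_nan(tf.math.log(probs), probs))

# Logsumexp form, following the same steps as the patched _entropy: shift by
# the max for stability (entropy is invariant to shifting the logits), then
# replace -inf logits with 1.0 so multiply_no_nan pairs a finite value with
# exp(-inf) == 0 instead of producing NaN.
shifted = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
lse = tf.reduce_logsumexp(shifted, axis=-1)
masked = tf.where(tf.math.is_inf(shifted) & (shifted < 0),
                  tf.cast(1.0, shifted.dtype), shifted)
lse_form = lse - tf.reduce_sum(
    tf.math.multiply_no_nan(masked, tf.math.exp(shifted)),
    axis=-1) / tf.math.exp(lse)

print(naive.numpy(), lse_form.numpy())  # both are approximately 0.62 nats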