
Commit 7bd5363

Johannes Ballé authored and copybara-github committed
Normalizes power law entropy model penalty differently.
This changes the penalty to be normalized such that it is non-negative, rather than representing a normalized distribution. This is more intuitive from an optimization perspective; the only other practical effect is that the effective weight of the penalty changes.

PiperOrigin-RevId: 448326753
Change-Id: I4baf22a3446062622d7c643b3040bb6ac22b5fe7
1 parent: c89da53

2 files changed: +30 −30 lines changed
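The two penalties in the diff below, the old `-log_2 p(x)` with `p(x) = alpha / 2 * (abs(x) + alpha) ** -2` and the new `log((abs(x) + alpha) / alpha)`, differ per element only by a fixed scale and offset. A minimal NumPy sketch, not part of the commit and using an arbitrary choice of `alpha` and sample values, verifying that relationship and the non-negativity:

```python
import numpy as np

alpha = 0.01                                # arbitrary small positive value
x = np.array([-5.0, -0.5, 0.0, 0.5, 5.0])   # arbitrary bottleneck values

# Old penalty: -log_2 p(x), with p(x) = alpha / 2 * (abs(x) + alpha) ** -2.
old = (1.0 - np.log2(alpha)) + 2.0 * np.log2(np.abs(x) + alpha)

# New penalty: zero at x == 0 and non-negative everywhere.
new = np.log((np.abs(x) + alpha) / alpha)

# Per element, old = scale * new + offset, so only the effective weight
# (and a constant) of the penalty term changes.
scale, offset = 2.0 / np.log(2.0), 1.0 + np.log2(alpha)
assert np.allclose(old, scale * new + offset)
assert np.all(new >= 0.0)
```

In a rate-distortion loss, the rescaling is absorbed by retuning the penalty weight, which is the "effective weight" change the commit message mentions.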

tensorflow_compression/python/entropy_models/power_law.py

Lines changed: 14 additions & 14 deletions
@@ -27,11 +27,16 @@
 class PowerLawEntropyModel(tf.Module):
   """Entropy model for power-law distributed random variables.

-  This entropy model handles quantization of a bottleneck tensor and implements
-  a cross entropy penalty that is consistent with the Elias gamma code.
+  This entropy model handles quantization and compression of a bottleneck tensor
+  and implements a penalty that encourages compressibility under the Elias gamma
+  code.

   The gamma code has code lengths `1 + 2 floor(log_2(x))`, for `x` a positive
-  integer. For details on the gamma code, see:
+  integer, and is close to optimal if `x` is distributed according to a power
+  law. Being a universal code, it also guarantees that in the worst case, the
+  expected code length is no more than 3 times the entropy of the empirical
+  distribution of `x`, as long as probability decreases with increasing `x`. For
+  details on the gamma code, see:

   > "Universal Codeword Sets and Representations of the Integers"<br />
   > P. Elias<br />
@@ -43,13 +48,12 @@ class PowerLawEntropyModel(tf.Module):

   The penalty applied by this class is given by:
   ```
-  -log_2 p(x), with p(x) = alpha / 2 * (x + alpha) ** -2
+  log((abs(x) + alpha) / alpha)
   ```
-  Like the gamma code, this follows a symmetrized power law, but only
-  approximately for `alpha > 0`. Without `alpha`, the distribution would not be
-  normalizable, and the penalty would have a singularity at zero. Setting
-  `alpha` to a small positive value ensures that the penalty is non-negative,
-  and that its gradients are useful for optimization.
+  This encourages `x` to follow a symmetrized power law, but only approximately
+  for `alpha > 0`. Without `alpha`, the penalty would have a singularity at
+  zero. Setting `alpha` to a small positive value ensures that the penalty is
+  non-negative, and that its gradients are useful for optimization.
   """

   def __init__(self,
@@ -123,11 +127,7 @@ def penalty(self, bottleneck):
       entropy.
     """
     bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
-    log_alpha = tf.math.log(
-        tf.constant(self.alpha, dtype=self.bottleneck_dtype))
-    log_2 = tf.math.log(tf.constant(2, dtype=self.bottleneck_dtype))
-    penalty = ((1. - log_alpha / log_2) +
-               tf.math.log(abs(bottleneck) + self.alpha) * (2. / log_2))
+    penalty = tf.math.log((abs(bottleneck) + self.alpha) / self.alpha)
     return tf.reduce_sum(penalty, axis=tuple(range(-self.coding_rank, 0)))

   @tf.Module.with_name_scope
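For `abs(x)` well above `alpha`, the new penalty grows like `log(x)` while the gamma code length `1 + 2 floor(log_2(x))` grows like `2 log_2(x)`, so the two are affinely related rather than equal, which is why the updated test below checks correlation instead of a fixed gap. A small illustrative sketch (the `alpha` value and sample points are arbitrary):

```python
import math

def gamma_code_length(x: int) -> int:
  """Elias gamma code length per the docstring: 1 + 2 * floor(log_2(x))."""
  return 1 + 2 * math.floor(math.log2(x))

alpha = 0.01  # arbitrary small positive value
for x in (1, 2, 10, 100, 1000):
  penalty = math.log((abs(x) + alpha) / alpha)
  print(f"x={x:5d}  code length={gamma_code_length(x):2d}  penalty={penalty:5.2f}")
```

Both columns grow linearly in `log(x)`, so a high correlation is expected even though the constants differ.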

tensorflow_compression/python/entropy_models/power_law_test.py

Lines changed: 16 additions & 16 deletions
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Tests of power law entropy model."""

+import numpy as np
 import tensorflow as tf
 from tensorflow_compression.python.entropy_models.power_law import PowerLawEntropyModel

@@ -56,27 +57,26 @@ def test_compression_consistent_with_quantization(self):

   def test_penalty_is_proportional_to_code_length(self):
     em = PowerLawEntropyModel(coding_rank=1)
-    # Sample some values from a Laplacian distribution.
-    u = tf.random.uniform((100, 1), minval=-1., maxval=1.)
-    values = 100. * tf.math.log(abs(u)) * tf.sign(u)
-    # Ensure there are some large values.
-    self.assertGreater(tf.reduce_sum(tf.cast(abs(values) > 100, tf.int32)), 0)
-    strings = em.compress(tf.broadcast_to(values, (100, 100)))
+    x = tf.range(-20., 20.)[:, None]
+    x += tf.random.uniform(x.shape, -.49, .49)
+    strings = em.compress(tf.broadcast_to(x, (40, 100)))
     code_lengths = tf.cast(tf.strings.length(strings, unit="BYTE"), tf.float32)
     code_lengths *= 8 / 100
-    penalties = em.penalty(values)
-    self.assertAllInRange(penalties - code_lengths, 4, 7)
+    penalties = em.penalty(x)
+    # There are some fluctuations due to `alpha`, `floor`, and rounding, but we
+    # expect a high degree of correlation between code lengths and penalty.
+    self.assertGreater(np.corrcoef(code_lengths, penalties)[0, 1], .96)

-  def test_penalty_is_differentiable(self):
+  def test_penalty_is_nonnegative_and_differentiable(self):
     em = PowerLawEntropyModel(coding_rank=1)
-    # Sample some values from a Laplacian distribution.
-    u = tf.random.uniform((100, 1), minval=-1., maxval=1.)
-    values = 100. * tf.math.log(abs(u)) * tf.sign(u)
+    x = tf.range(-20., 20.)[:, None]
+    x += tf.random.uniform(x.shape, -.49, .49)
     with tf.GradientTape() as tape:
-      tape.watch(values)
-      penalties = em.penalty(values)
-    gradients = tape.gradient(penalties, values)
-    self.assertAllEqual(tf.sign(gradients), tf.sign(values))
+      tape.watch(x)
+      penalties = em.penalty(x)
+    gradients = tape.gradient(penalties, x)
+    self.assertAllGreaterEqual(penalties, 0)
+    self.assertAllEqual(tf.sign(gradients), tf.sign(x))

   def test_compression_works_in_tf_function(self):
     samples = tf.random.stateless_normal([100], (34, 232))
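Read together, the tests suggest the intended usage pattern: `penalty` as a differentiable rate proxy during training, `compress` for actual entropy coding. A minimal sketch using only the calls exercised in this commit; the shapes and scale of the sample values are arbitrary:

```python
import tensorflow as tf
from tensorflow_compression.python.entropy_models.power_law import PowerLawEntropyModel

em = PowerLawEntropyModel(coding_rank=1)
x = tf.random.normal((8, 100), stddev=10.)  # arbitrary bottleneck values

# Differentiable, non-negative rate proxy: one scalar per batch element,
# suitable as an additive term in a training loss.
rate = em.penalty(x)

# Actual compression: one string per batch element.
strings = em.compress(x)
bits = 8. * tf.cast(tf.strings.length(strings, unit="BYTE"), tf.float32)
```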
