
Commit c89da53

Johannes Ballé authored and copybara-github committed
Adds power law entropy model for use with run-length gamma code.
PiperOrigin-RevId: 447742176
Change-Id: Ib35819497aed4fd5e34a9c03312e08783b1c1a75
1 parent 6d1f106 commit c89da53

File tree

6 files changed: +352 -0 lines changed

tensorflow_compression/all_tests.py
tensorflow_compression/python/entropy_models/BUILD
tensorflow_compression/python/entropy_models/__init__.py
tensorflow_compression/python/entropy_models/power_law.py
tensorflow_compression/python/entropy_models/power_law_test.py
tensorflow_compression/python/entropy_models/universal.py

tensorflow_compression/all_tests.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 
 from tensorflow_compression.python.entropy_models.continuous_batched_test import *
 from tensorflow_compression.python.entropy_models.continuous_indexed_test import *
+from tensorflow_compression.python.entropy_models.power_law_test import *
 from tensorflow_compression.python.entropy_models.universal_test import *
 
 from tensorflow_compression.python.layers.gdn_test import *

tensorflow_compression/python/entropy_models/BUILD

Lines changed: 16 additions & 0 deletions
@@ -10,6 +10,7 @@ py_library(
     deps = [
         ":continuous_batched",
         ":continuous_indexed",
+        ":power_law",
         ":universal",
     ],
 )
@@ -66,6 +67,21 @@ py_test(
     ],
 )
 
+py_library(
+    name = "power_law",
+    srcs = ["power_law.py"],
+    deps = [
+        "//tensorflow_compression/python/ops:gen_ops",
+        "//tensorflow_compression/python/ops:round_ops",
+    ],
+)
+
+py_test(
+    name = "power_law_test",
+    srcs = ["power_law_test.py"],
+    deps = [":power_law"],
+)
+
 py_library(
     name = "universal",
     srcs = ["universal.py"],

tensorflow_compression/python/entropy_models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,4 +16,5 @@
 
 from tensorflow_compression.python.entropy_models.continuous_batched import *
 from tensorflow_compression.python.entropy_models.continuous_indexed import *
+from tensorflow_compression.python.entropy_models.power_law import *
 from tensorflow_compression.python.entropy_models.universal import *
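With the wildcard import above in place, the new class should also be reachable from the package namespace. A quick sketch, assuming the top-level `tensorflow_compression` package re-exports this module's symbols the same way it does for the other entropy models:

```python
import tensorflow_compression as tfc

em = tfc.PowerLawEntropyModel(coding_rank=1)
```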

tensorflow_compression/python/entropy_models/power_law.py

Lines changed: 210 additions & 0 deletions (new file)

# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An entropy model for the run-length gamma code."""

import tensorflow as tf
from tensorflow_compression.python.ops import gen_ops
from tensorflow_compression.python.ops import round_ops


__all__ = [
    "PowerLawEntropyModel",
]


class PowerLawEntropyModel(tf.Module):
  """Entropy model for power-law distributed random variables.

  This entropy model handles quantization of a bottleneck tensor and implements
  a cross entropy penalty that is consistent with the Elias gamma code.

  The gamma code has code lengths `1 + 2 floor(log_2(x))`, for `x` a positive
  integer. For details on the gamma code, see:

  > "Universal Codeword Sets and Representations of the Integers"<br />
  > P. Elias<br />
  > https://doi.org/10.1109/TIT.1975.1055349

  Given signed integers, `run_length_gamma_encode` encodes runs of zeros with a
  run-length code, the sign with a uniform bit, and the magnitude with the
  gamma code.

  The penalty applied by this class is given by:
  ```
  -log_2 p(x), with p(x) = alpha / 2 * (|x| + alpha) ** -2
  ```
  Like the gamma code, this follows a symmetrized power law, but only
  approximately for `alpha > 0`. Without `alpha`, the distribution would not be
  normalizable, and the penalty would have a singularity at zero. Setting
  `alpha` to a small positive value ensures that the penalty is non-negative,
  and that its gradients are useful for optimization.
  """

  def __init__(self,
               coding_rank,
               alpha=1e-2,
               bottleneck_dtype=None):
    """Initializes the instance.

    Args:
      coding_rank: Integer. Number of innermost dimensions considered a coding
        unit. Each coding unit is compressed to its own bit string, and the
        estimated rate is summed over each coding unit in `penalty()`.
      alpha: Float. Regularization parameter preventing gradient singularity
        around zero.
      bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
        Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
    """
    self._coding_rank = int(coding_rank)
    if self.coding_rank < 0:
      raise ValueError("`coding_rank` must be at least 0.")
    self._alpha = float(alpha)
    if self.alpha <= 0:
      raise ValueError("`alpha` must be greater than 0.")
    if bottleneck_dtype is None:
      bottleneck_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
    if bottleneck_dtype is None:
      bottleneck_dtype = tf.keras.backend.floatx()
    self._bottleneck_dtype = tf.as_dtype(bottleneck_dtype)
    super().__init__()

  @property
  def alpha(self):
    """Alpha parameter."""
    return self._alpha

  @property
  def bottleneck_dtype(self):
    """Data type of the bottleneck tensor."""
    return self._bottleneck_dtype

  @property
  def coding_rank(self):
    """Number of innermost dimensions considered a coding unit."""
    return self._coding_rank

  @tf.Module.with_name_scope
  def __call__(self, bottleneck):
    """Quantizes a bottleneck tensor and computes the rate penalty.

    Args:
      bottleneck: `tf.Tensor` containing the data to be compressed. Must have
        at least `self.coding_rank` dimensions.

    Returns:
      A tuple `(self.quantize(bottleneck), self.penalty(bottleneck))`.
    """
    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
    return self.quantize(bottleneck), self.penalty(bottleneck)

  @tf.Module.with_name_scope
  def penalty(self, bottleneck):
    """Computes the cross-entropy penalty.

    Args:
      bottleneck: `tf.Tensor` containing the data to be compressed. Must have
        at least `self.coding_rank` dimensions.

    Returns:
      Penalty, which has the same shape as `bottleneck` without the
      `self.coding_rank` innermost dimensions, and corresponds to a cross
      entropy.
    """
    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
    log_alpha = tf.math.log(
        tf.constant(self.alpha, dtype=self.bottleneck_dtype))
    log_2 = tf.math.log(tf.constant(2, dtype=self.bottleneck_dtype))
    penalty = ((1. - log_alpha / log_2) +
               tf.math.log(abs(bottleneck) + self.alpha) * (2. / log_2))
    return tf.reduce_sum(penalty, axis=tuple(range(-self.coding_rank, 0)))

  @tf.Module.with_name_scope
  def quantize(self, bottleneck):
    """Quantizes a floating-point bottleneck tensor.

    The tensor is rounded to integer values. The gradient of this rounding
    operation is overridden with the identity (straight-through gradient
    estimator).

    Args:
      bottleneck: `tf.Tensor` containing the data to be quantized.

    Returns:
      A `tf.Tensor` containing the quantized values.
    """
    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
    return round_ops.round_st(bottleneck)

  @tf.Module.with_name_scope
  def compress(self, bottleneck):
    """Compresses a floating-point tensor.

    Compresses the tensor to bit strings. `bottleneck` is first quantized
    as in `quantize()`, and then compressed using the run-length gamma code.
    The quantized tensor can later be recovered by calling `decompress()`.

    The innermost `self.coding_rank` dimensions are treated as one coding unit,
    i.e. are compressed into one string each. Any additional dimensions to the
    left are treated as batch dimensions.

    Args:
      bottleneck: `tf.Tensor` containing the data to be compressed. Must have
        at least `self.coding_rank` dimensions.

    Returns:
      A `tf.Tensor` having the same shape as `bottleneck` without the
      `self.coding_rank` innermost dimensions, containing a string for each
      coding unit.
    """
    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)

    shape = tf.shape(bottleneck)
    if self.coding_rank == 0:
      flat_shape = [-1]
      strings_shape = shape
    else:
      flat_shape = tf.concat([[-1], shape[-self.coding_rank:]], 0)
      strings_shape = shape[:-self.coding_rank]

    symbols = tf.cast(tf.round(bottleneck), tf.int32)
    symbols = tf.reshape(symbols, flat_shape)

    strings = tf.map_fn(
        gen_ops.run_length_gamma_encode, symbols,
        fn_output_signature=tf.TensorSpec((), dtype=tf.string))
    return tf.reshape(strings, strings_shape)

  @tf.Module.with_name_scope
  def decompress(self, strings, code_shape):
    """Decompresses a tensor.

    Reconstructs the quantized tensor from bit strings produced by
    `compress()`.

    Args:
      strings: `tf.Tensor` containing the compressed bit strings.
      code_shape: Shape of innermost dimensions of the output `tf.Tensor`.

    Returns:
      A `tf.Tensor` of shape `tf.shape(strings) + code_shape`.
    """
    strings = tf.convert_to_tensor(strings, dtype=tf.string)
    strings_shape = tf.shape(strings)
    symbols = tf.map_fn(
        lambda x: gen_ops.run_length_gamma_decode(x, code_shape),
        tf.reshape(strings, [-1]),
        fn_output_signature=tf.TensorSpec(
            [None] * self.coding_rank, dtype=tf.int32))
    symbols = tf.reshape(symbols, tf.concat([strings_shape, code_shape], 0))
    return tf.cast(symbols, self.bottleneck_dtype)
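For orientation, a minimal usage sketch of the class added above. The shapes, the random input, and the loss wiring are illustrative assumptions, not part of the commit:

```python
import tensorflow as tf
from tensorflow_compression.python.entropy_models.power_law import PowerLawEntropyModel

# One coding unit per innermost axis: each row is compressed to its own string.
em = PowerLawEntropyModel(coding_rank=1)

x = tf.random.normal((4, 16))        # hypothetical bottleneck batch
x_tilde, penalty = em(x)             # straight-through rounding + rate penalty, shape (4,)
rate_loss = tf.reduce_mean(penalty)  # differentiable rate term for training

strings = em.compress(x)               # tf.string tensor of shape (4,)
x_hat = em.decompress(strings, (16,))  # recovers tf.round(x), shape (4, 16)
```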

tensorflow_compression/python/entropy_models/power_law_test.py

Lines changed: 122 additions & 0 deletions (new file)

# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests of power law entropy model."""

import tensorflow as tf
from tensorflow_compression.python.entropy_models.power_law import PowerLawEntropyModel


class PowerLawEntropyModelTest(tf.test.TestCase):

  def test_can_instantiate(self):
    em = PowerLawEntropyModel(coding_rank=1)
    self.assertEqual(em.coding_rank, 1)
    self.assertEqual(em.bottleneck_dtype, tf.float32)

  def test_requires_coding_rank_greater_equal_zero(self):
    with self.assertRaises(ValueError):
      PowerLawEntropyModel(coding_rank=-1)

  def test_quantizes_to_integers(self):
    em = PowerLawEntropyModel(coding_rank=1)
    x = tf.range(-20., 20.)
    x_perturbed = x + tf.random.uniform(x.shape, -.49, .49)
    x_quantized = em.quantize(x_perturbed)
    self.assertAllEqual(x, x_quantized)

  def test_gradients_are_straight_through(self):
    em = PowerLawEntropyModel(coding_rank=1)
    x = tf.range(-20., 20.)
    x_perturbed = x + tf.random.uniform(x.shape, -.49, .49)
    with tf.GradientTape() as tape:
      tape.watch(x_perturbed)
      x_quantized = em.quantize(x_perturbed)
    gradients = tape.gradient(x_quantized, x_perturbed)
    self.assertAllEqual(gradients, tf.ones_like(gradients))

  def test_compression_consistent_with_quantization(self):
    em = PowerLawEntropyModel(coding_rank=1)
    x = tf.range(-20., 20.)
    x += tf.random.uniform(x.shape, -.49, .49)
    x_quantized = em.quantize(x)
    x_decompressed = em.decompress(em.compress(x), x.shape)
    self.assertAllEqual(x_decompressed, x_quantized)

  def test_penalty_is_proportional_to_code_length(self):
    em = PowerLawEntropyModel(coding_rank=1)
    # Sample some values from a Laplacian distribution.
    u = tf.random.uniform((100, 1), minval=-1., maxval=1.)
    values = 100. * tf.math.log(abs(u)) * tf.sign(u)
    # Ensure there are some large values.
    self.assertGreater(tf.reduce_sum(tf.cast(abs(values) > 100, tf.int32)), 0)
    strings = em.compress(tf.broadcast_to(values, (100, 100)))
    code_lengths = tf.cast(tf.strings.length(strings, unit="BYTE"), tf.float32)
    code_lengths *= 8 / 100
    penalties = em.penalty(values)
    self.assertAllInRange(penalties - code_lengths, 4, 7)

  def test_penalty_is_differentiable(self):
    em = PowerLawEntropyModel(coding_rank=1)
    # Sample some values from a Laplacian distribution.
    u = tf.random.uniform((100, 1), minval=-1., maxval=1.)
    values = 100. * tf.math.log(abs(u)) * tf.sign(u)
    with tf.GradientTape() as tape:
      tape.watch(values)
      penalties = em.penalty(values)
    gradients = tape.gradient(penalties, values)
    self.assertAllEqual(tf.sign(gradients), tf.sign(values))

  def test_compression_works_in_tf_function(self):
    samples = tf.random.stateless_normal([100], (34, 232))

    # Since tf.function traces each function twice, and only allows variable
    # creation in the first call, we need to have a stateful object in which we
    # create the entropy model only the first time the function is called, and
    # store it for the second time.

    class Compressor:

      def compress(self, values):
        if not hasattr(self, "em"):
          self.em = PowerLawEntropyModel(coding_rank=1)
        compressed = self.em.compress(values)
        return self.em.decompress(compressed, [100])

    values_eager = Compressor().compress(samples)
    values_function = tf.function(Compressor().compress)(samples)
    self.assertAllClose(samples, values_eager, rtol=0., atol=.5)
    self.assertAllEqual(values_eager, values_function)

  def test_dtypes_are_correct_with_mixed_precision(self):
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
      em = PowerLawEntropyModel(coding_rank=1)
      self.assertEqual(em.bottleneck_dtype, tf.float16)
      x = tf.random.stateless_normal((2, 5), seed=(0, 1), dtype=tf.float16)
      x_tilde, penalty = em(x)
      bitstring = em.compress(x)
      x_hat = em.decompress(bitstring, (5,))
      self.assertEqual(x_hat.dtype, tf.float16)
      self.assertAllClose(x, x_hat, rtol=0, atol=.5)
      self.assertEqual(x_tilde.dtype, tf.float16)
      self.assertAllClose(x, x_tilde, rtol=0, atol=.5)
      self.assertEqual(penalty.dtype, tf.float16)
      self.assertEqual(penalty.shape, (2,))
    finally:
      tf.keras.mixed_precision.set_global_policy(None)


if __name__ == "__main__":
  tf.test.main()
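To put numbers on the `[4, 7]` band asserted in `test_penalty_is_proportional_to_code_length`, here is a back-of-the-envelope check of the penalty against the bare gamma code length, in plain Python with the default `alpha = 1e-2` assumed (a sketch, not part of the commit):

```python
import math

alpha = 1e-2

def penalty_bits(x):
  # -log_2 p(x) with p(x) = alpha / 2 * (|x| + alpha) ** -2
  return (1 - math.log2(alpha)) + 2 * math.log2(abs(x) + alpha)

def gamma_bits(n):
  # Elias gamma code length for a positive integer magnitude n.
  return 1 + 2 * math.floor(math.log2(n))

for n in (1, 2, 10, 100):
  print(n, round(penalty_bits(n), 2), gamma_bits(n))
# n=1: 7.67 vs 1; n=2: 9.66 vs 3; n=10: 14.29 vs 7; n=100: 20.93 vs 13
```

The penalty exceeds the bare gamma length by roughly `-log_2(alpha) ≈ 6.6` bits plus a fractional term; the sign bit and run-length overhead of the actual code eat into that margin, which is consistent with the test's 4 to 7 bit range.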

tensorflow_compression/python/entropy_models/universal.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+# Copyright 2020 Google LLC. All Rights Reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
