Skip to content

Commit b996d5c

Browse files
Johannes Ball?copybara-github
authored andcommitted
Implements entropy models for continuous-valued random variables.
This reimplements the existing entropy models in `entropy_models.py` in a more modular way. The `EntropyBottleneck` class is now implemented through `ContinuousBatchedEntropyModel` and can be instantiated with arbitrary `Distribution` objects (use `DeepFactorized` to get the equivalent behavior to `EntropyBottleneck`). The old `*Conditional` classes are implemented via `ContinuousIndexedEntropyModel` and the special case `LocationScaleIndexedEntropyModel`. Like the batched version, they can be instantiated with arbitrary distribution objects (use `NoisyNormal` to get the equivalent behavior to `GaussianConditional`, etc.). In addition, `ContinuousIndexedEntropyModel` now also supports multi-dimensional indexing. The use case for that is to model conditionally independent scalar distributions which have more than one parameter that depends on data (e.g., a shape and scale parameter). PiperOrigin-RevId: 293468780 Change-Id: I8189ccd898d8ae33819c36feff7fbc971b8523d0
1 parent 77c2621 commit b996d5c

File tree

9 files changed

+1280
-0
lines changed

9 files changed

+1280
-0
lines changed

BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ py_library(
1212
visibility = ["//visibility:public"],
1313
deps = [
1414
"//tensorflow_compression/python/distributions",
15+
"//tensorflow_compression/python/entropy_models",
1516
"//tensorflow_compression/python/layers",
1617
"//tensorflow_compression/python/ops",
1718
"//tensorflow_compression/python/util",
@@ -36,6 +37,7 @@ py_binary(
3637
":tensorflow_compression",
3738
# The following targets are for Python test files.
3839
"//tensorflow_compression/python/distributions:py_src",
40+
"//tensorflow_compression/python/entropy_models:py_src",
3941
"//tensorflow_compression/python/layers:py_src",
4042
"//tensorflow_compression/python/ops:py_src",
4143
"//tensorflow_compression/python/util:py_src",

tensorflow_compression/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from tensorflow_compression.python.distributions.deep_factorized import *
2828
from tensorflow_compression.python.distributions.helpers import *
2929
from tensorflow_compression.python.distributions.uniform_noise import *
30+
from tensorflow_compression.python.entropy_models.continuous_batched import *
31+
from tensorflow_compression.python.entropy_models.continuous_indexed import *
3032
from tensorflow_compression.python.layers.entropy_models import *
3133
from tensorflow_compression.python.layers.gdn import *
3234
from tensorflow_compression.python.layers.initializers import *
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package(
2+
default_visibility = ["//:__subpackages__"],
3+
)
4+
5+
licenses(["notice"]) # Apache 2.0
6+
7+
py_library(
8+
name = "entropy_models",
9+
srcs = ["__init__.py"],
10+
srcs_version = "PY3",
11+
deps = [
12+
":continuous_base",
13+
":continuous_batched",
14+
":continuous_indexed",
15+
],
16+
)
17+
18+
py_library(
19+
name = "continuous_base",
20+
srcs = ["continuous_base.py"],
21+
srcs_version = "PY3",
22+
deps = [
23+
"//tensorflow_compression/python/distributions:helpers",
24+
"//tensorflow_compression/python/ops:range_coding_ops",
25+
],
26+
)
27+
28+
py_library(
29+
name = "continuous_batched",
30+
srcs = ["continuous_batched.py"],
31+
srcs_version = "PY3",
32+
deps = [
33+
":continuous_base",
34+
"//tensorflow_compression/python/ops:math_ops",
35+
"//tensorflow_compression/python/ops:range_coding_ops",
36+
],
37+
)
38+
39+
py_test(
40+
name = "continuous_batched_test",
41+
srcs = ["continuous_batched_test.py"],
42+
python_version = "PY3",
43+
deps = [
44+
":continuous_batched",
45+
"//tensorflow_compression/python/distributions:uniform_noise",
46+
],
47+
)
48+
49+
py_library(
50+
name = "continuous_indexed",
51+
srcs = ["continuous_indexed.py"],
52+
srcs_version = "PY3",
53+
deps = [
54+
":continuous_base",
55+
"//tensorflow_compression/python/distributions:helpers",
56+
"//tensorflow_compression/python/ops:math_ops",
57+
"//tensorflow_compression/python/ops:range_coding_ops",
58+
],
59+
)
60+
61+
py_test(
62+
name = "continuous_indexed_test",
63+
srcs = ["continuous_indexed_test.py"],
64+
python_version = "PY3",
65+
deps = [
66+
":continuous_indexed",
67+
"//tensorflow_compression/python/distributions:uniform_noise",
68+
],
69+
)
70+
71+
filegroup(
72+
name = "py_src",
73+
srcs = glob(["*.py"]),
74+
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Google LLC. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# Lint as: python3
2+
# Copyright 2020 Google LLC. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Base class for continuous entropy models."""
17+
18+
import abc
19+
20+
from absl import logging
21+
import tensorflow.compat.v2 as tf
22+
23+
from tensorflow_compression.python.distributions import helpers
24+
from tensorflow_compression.python.ops import range_coding_ops
25+
26+
27+
__all__ = ["ContinuousEntropyModelBase"]
28+
29+
30+
class ContinuousEntropyModelBase(tf.Module, metaclass=abc.ABCMeta):
31+
"""Base class for continuous entropy models.
32+
33+
The basic functionality of this class is to pre-compute integer probability
34+
tables based on the provided `tfp.distributions.Distribution` object, which
35+
can then be used reliably across different platforms by the range coder and
36+
decoder.
37+
"""
38+
39+
@abc.abstractmethod
40+
def __init__(self, distribution, coding_rank,
41+
likelihood_bound=1e-9, tail_mass=2**-8,
42+
range_coder_precision=12):
43+
"""Initializer.
44+
45+
Arguments:
46+
distribution: A `tfp.distributions.Distribution` object modeling the
47+
distribution of the input data including additive uniform noise. For
48+
best results, the distribution should be flexible enough to have a
49+
unit-width uniform distribution as a special case.
50+
coding_rank: Integer. Number of innermost dimensions considered a coding
51+
unit. Each coding unit is compressed to its own bit string, and the
52+
`bits()` method sums over each coding unit.
53+
likelihood_bound: Float. Lower bound for likelihood values, to prevent
54+
training instabilities.
55+
tail_mass: Float. Approximate probability mass which is range encoded with
56+
less precision, by using a Golomb-like code.
57+
range_coder_precision: Integer. Precision passed to the range coding op.
58+
"""
59+
if not distribution.is_scalar_event():
60+
raise ValueError(
61+
"`distribution` must be a (batch of) scalar distribution(s).")
62+
super().__init__()
63+
self._distribution = distribution
64+
self._coding_rank = int(coding_rank)
65+
self._likelihood_bound = float(likelihood_bound)
66+
self._tail_mass = float(tail_mass)
67+
self._range_coder_precision = int(range_coder_precision)
68+
self.update_tables()
69+
70+
@property
71+
def distribution(self):
72+
"""Distribution modeling data + i.i.d. uniform noise."""
73+
return self._distribution
74+
75+
@property
76+
def coding_rank(self):
77+
"""Number of innermost dimensions considered a coding unit."""
78+
return self._coding_rank
79+
80+
@property
81+
def likelihood_bound(self):
82+
"""Lower bound for likelihood values."""
83+
return self._likelihood_bound
84+
85+
@property
86+
def tail_mass(self):
87+
"""Approximate probability mass which is range encoded with overflow."""
88+
return self._tail_mass
89+
90+
@property
91+
def range_coder_precision(self):
92+
"""Precision passed to range coding op."""
93+
return self._range_coder_precision
94+
95+
@property
96+
def dtype(self):
97+
"""Data type of this distribution."""
98+
return self.distribution.dtype
99+
100+
def quantization_offset(self):
101+
"""Distribution-dependent quantization offset."""
102+
return helpers.quantization_offset(self.distribution)
103+
104+
def lower_tail(self):
105+
"""Approximate lower tail quantile for range coding."""
106+
return helpers.lower_tail(self.distribution, self.tail_mass)
107+
108+
def upper_tail(self):
109+
"""Approximate upper tail quantile for range coding."""
110+
return helpers.upper_tail(self.distribution, self.tail_mass)
111+
112+
@tf.custom_gradient
113+
def _quantize(self, inputs, offset):
114+
return tf.round(inputs - offset) + offset, lambda x: (x, None)
115+
116+
def update_tables(self):
117+
"""Updates integer-valued probability tables used by the range coder.
118+
119+
These tables must not be re-generated independently on the sending and
120+
receiving side, since small numerical discrepancies between both sides can
121+
occur in this process. If the tables differ slightly, this in turn would
122+
very likely cause catastrophic error propagation during range decoding. For
123+
a more in-depth discussion of this, see:
124+
125+
> "Integer Networks for Data Compression with Latent-Variable Models"<br />
126+
> J. Ballé, N. Johnston, D. Minnen<br />
127+
> https://openreview.net/forum?id=S1zz2i0cY7
128+
129+
The tables are stored in `tf.Tensor`s as attributes of this object. The
130+
recommended way is to train the model, then call this method, and then
131+
distribute the model to a sender and a receiver.
132+
"""
133+
offset = self.quantization_offset()
134+
lower_tail = self.lower_tail()
135+
upper_tail = self.upper_tail()
136+
137+
# Largest distance observed between lower tail and median, and between
138+
# median and upper tail.
139+
minima = offset - lower_tail
140+
minima = tf.cast(tf.math.ceil(minima), tf.int32)
141+
minima = tf.math.maximum(minima, 0)
142+
maxima = upper_tail - offset
143+
maxima = tf.cast(tf.math.ceil(maxima), tf.int32)
144+
maxima = tf.math.maximum(maxima, 0)
145+
146+
# PMF starting positions and lengths.
147+
pmf_start = offset - tf.cast(minima, self.dtype)
148+
pmf_length = maxima + minima + 1
149+
150+
# Sample the densities in the computed ranges, possibly computing more
151+
# samples than necessary at the upper end.
152+
max_length = tf.math.reduce_max(pmf_length)
153+
if max_length > 2048:
154+
logging.warning(
155+
"Very wide PMF with %d elements may lead to out of memory issues. "
156+
"Consider encoding distributions with smaller dispersion or "
157+
"increasing `tail_mass` parameter.", int(max_length))
158+
samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype)
159+
samples = tf.reshape(
160+
samples, [-1] + self.distribution.batch_shape.rank * [1])
161+
samples += pmf_start
162+
pmf = self.distribution.prob(samples)
163+
164+
# Collapse batch dimensions of distribution.
165+
pmf = tf.reshape(pmf, [max_length, -1])
166+
pmf = tf.transpose(pmf)
167+
168+
dist_shape = self.distribution.batch_shape_tensor()
169+
pmf_length = tf.broadcast_to(pmf_length, dist_shape)
170+
pmf_length = tf.reshape(pmf_length, [-1])
171+
cdf_length = pmf_length + 2
172+
cdf_offset = tf.broadcast_to(-minima, dist_shape)
173+
cdf_offset = tf.reshape(cdf_offset, [-1])
174+
175+
# Prevent tensors from bouncing back and forth between host and GPU.
176+
with tf.device("/cpu:0"):
177+
def loop_body(args):
178+
prob, length = args
179+
prob = prob[:length]
180+
prob = tf.concat([prob, 1 - tf.reduce_sum(prob, keepdims=True)], axis=0)
181+
cdf = range_coding_ops.pmf_to_quantized_cdf(
182+
prob, precision=self.range_coder_precision)
183+
return tf.pad(
184+
cdf, [[0, max_length - length]], mode="CONSTANT", constant_values=0)
185+
186+
# TODO(jonycgn,ssjhv): Consider switching to Python control flow.
187+
cdf = tf.map_fn(
188+
loop_body, (pmf, pmf_length), dtype=tf.int32, name="pmf_to_cdf")
189+
190+
self._cdf, self._cdf_offset, self._cdf_length = cdf, cdf_offset, cdf_length

0 commit comments

Comments
 (0)