
Commit 77c2621

Johannes Ballé authored and copybara-github committed
Implements distribution objects for entropy modeling.
This reimplements the distributions from our existing entropy models in `entropy_models.py` as separate classes. They tie in to the `tensorflow_probability` class hierarchy. The `DeepFactorized` class implements the density model currently implemented in `EntropyBottleneck`. The `*Conditional` classes, which are based on existing standard distributions and modified by adding uniform noise, are now implemented via a `UniformNoiseAdapter` distribution, which takes a base distribution as argument and implements the additive noise modification. It has some similarity to the `QuantizedDistribution` class in `tensorflow_probability`, and some parts of it are based on code I took from there. Making the additive noise modification modular should enable experimentation with a larger variety of density models more easily.

PiperOrigin-RevId: 293462740
Change-Id: I275d2c7a3bb8177424546755492be774b5222bf9
1 parent a6eec1d · commit 77c2621
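
The additive-noise construction described in the commit message models a latent y as x + u with u uniform on (-.5, .5), so the density of y is the base CDF difference F(y + .5) - F(y - .5). Below is a minimal sketch of that computation using `tensorflow_probability` directly; the commented `UniformNoiseAdapter` lines assume the takes-a-base-distribution constructor described above and are illustrative only:

    import tensorflow.compat.v2 as tf
    import tensorflow_probability as tfp

    # Density of y = x + u, u ~ Uniform(-0.5, 0.5), for a Gaussian base:
    base = tfp.distributions.Normal(loc=0., scale=1.)
    y = tf.linspace(-3., 3., 7)
    p = base.cdf(y + .5) - base.cdf(y - .5)  # what the adapter computes, conceptually

    # With the class introduced by this commit (assumed call signature):
    # noisy = UniformNoiseAdapter(base)
    # p = noisy.prob(y)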

File tree

11 files changed: +1032 -0 lines changed

BUILD

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@ py_library(
     srcs_version = "PY3",
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow_compression/python/distributions",
         "//tensorflow_compression/python/layers",
         "//tensorflow_compression/python/ops",
         "//tensorflow_compression/python/util",
@@ -34,6 +35,7 @@ py_binary(
         ":pip_src",
         ":tensorflow_compression",
         # The following targets are for Python test files.
+        "//tensorflow_compression/python/distributions:py_src",
         "//tensorflow_compression/python/layers:py_src",
         "//tensorflow_compression/python/ops:py_src",
         "//tensorflow_compression/python/util:py_src",

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 scipy >= 1
+tensorflow_probability >= 0.9

tensorflow_compression/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@
 
 
 # pylint: disable=wildcard-import
+from tensorflow_compression.python.distributions.deep_factorized import *
+from tensorflow_compression.python.distributions.helpers import *
+from tensorflow_compression.python.distributions.uniform_noise import *
 from tensorflow_compression.python.layers.entropy_models import *
 from tensorflow_compression.python.layers.gdn import *
 from tensorflow_compression.python.layers.initializers import *
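
With these wildcard imports in place, the new classes are exported at the package top level. A quick sketch (grounded in the `__all__` list of deep_factorized.py below):

    import tensorflow_compression as tfc

    df = tfc.DeepFactorized(batch_shape=(8,))  # same class as in the new module
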
tensorflow_compression/python/distributions/BUILD

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+package(
+    default_visibility = ["//:__subpackages__"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+py_library(
+    name = "distributions",
+    srcs = ["__init__.py"],
+    srcs_version = "PY3",
+    deps = [
+        ":deep_factorized",
+        ":helpers",
+        ":uniform_noise",
+    ],
+)
+
+py_library(
+    name = "deep_factorized",
+    srcs = ["deep_factorized.py"],
+    srcs_version = "PY3",
+    deps = [":helpers"],
+)
+
+py_test(
+    name = "deep_factorized_test",
+    srcs = ["deep_factorized_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":deep_factorized",
+        ":helpers",
+    ],
+)
+
+py_library(
+    name = "helpers",
+    srcs = ["helpers.py"],
+    srcs_version = "PY3",
+)
+
+py_test(
+    name = "helpers_test",
+    srcs = ["helpers_test.py"],
+    python_version = "PY3",
+    deps = [":helpers"],
+)
+
+py_library(
+    name = "uniform_noise",
+    srcs = ["uniform_noise.py"],
+    srcs_version = "PY3",
+    deps = [":helpers"],
+)
+
+py_test(
+    name = "uniform_noise_test",
+    srcs = ["uniform_noise_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":helpers",
+        ":uniform_noise",
+    ],
+)
+
+filegroup(
+    name = "py_src",
+    srcs = glob(["*.py"]),
+)
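
The test targets above can be run with Bazel in the usual way; for example (package path inferred from the `:py_src` dependency added to the top-level BUILD):

    bazel test //tensorflow_compression/python/distributions:deep_factorized_test
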
tensorflow_compression/python/distributions/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# Copyright 2020 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
tensorflow_compression/python/distributions/deep_factorized.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
+# Lint as: python3
+# Copyright 2020 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deep fully factorized distribution based on cumulative."""
+
+import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
+
+from tensorflow_compression.python.distributions import helpers
+
+
+__all__ = ["DeepFactorized"]
+
+
+class DeepFactorized(tfp.distributions.Distribution):
+  """Fully factorized distribution based on neural network cumulative.
+
+  This is a flexible, nonparametric probability density model, described in
+  appendix 6.1 of the paper:
+
+  > "Variational image compression with a scale hyperprior"<br />
+  > J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston<br />
+  > https://openreview.net/forum?id=rkcQFMZRb
+
+  This implementation already includes convolution with a unit-width uniform
+  density, as described in appendix 6.2 of the same paper. Please cite the paper
+  if you use this code for scientific work.
+
+  This is a scalar distribution (i.e., its `event_shape` is always length 0),
+  and the density object always creates its own `tf.Variable`s representing the
+  trainable distribution parameters.
+  """
+
+  def __init__(self, batch_shape=(), num_filters=(3, 3), init_scale=10,
+               allow_nan_stats=False, dtype=tf.float32, name="DeepFactorized"):
+    """Initializer.
+
+    Arguments:
+      batch_shape: Iterable of integers. The desired batch shape for the
+        `Distribution` (rightmost dimensions which are assumed independent, but
+        not identically distributed).
+      num_filters: Iterable of integers. The number of filters for each of the
+        hidden layers. The first and last layer of the network implementing the
+        cumulative distribution are not included (they are assumed to be 1).
+      init_scale: Float. Scale factor for the density at initialization. It is
+        recommended to choose a large enough scale factor such that most values
+        initially lie within a region of high likelihood. This improves
+        training.
+      allow_nan_stats: Boolean. Whether to allow `NaN`s to be returned when
+        querying distribution statistics.
+      dtype: A floating point `tf.dtypes.DType`. Computations relating to this
+        distribution will be performed at this precision.
+      name: String. A name for this distribution.
+    """
+    parameters = dict(locals())
+    self._batch_shape_tuple = tuple(int(s) for s in batch_shape)
+    self._num_filters = tuple(int(f) for f in num_filters)
+    self._init_scale = float(init_scale)
+    self._estimated_tail_mass = None
+    super().__init__(
+        dtype=dtype,
+        reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED,
+        validate_args=False,
+        allow_nan_stats=allow_nan_stats,
+        parameters=parameters,
+        name=name,
+    )
+    self._make_variables()
+
+  @property
+  def num_filters(self):
+    return self._num_filters
+
+  @property
+  def init_scale(self):
+    return self._init_scale
+
+  def _make_variables(self):
+    """Creates the variables representing the parameters of the distribution."""
+    channels = self.batch_shape.num_elements()
+    filters = (1,) + self.num_filters + (1,)
+    scale = self.init_scale ** (1 / (len(self.num_filters) + 1))
+    self._matrices = []
+    self._biases = []
+    self._factors = []
+
+    for i in range(len(self.num_filters) + 1):
+      init = tf.math.log(tf.math.expm1(1 / scale / filters[i + 1]))
+      init = tf.cast(init, dtype=self.dtype)
+      init = tf.broadcast_to(init, (channels, filters[i + 1], filters[i]))
+      matrix = tf.Variable(init, name="matrix_{}".format(i))
+      self._matrices.append(matrix)
+
+      bias = tf.Variable(
+          tf.random.uniform(
+              (channels, filters[i + 1], 1), -.5, .5, dtype=self.dtype),
+          name="bias_{}".format(i))
+      self._biases.append(bias)
+
+      if i < len(self.num_filters):
+        factor = tf.Variable(
+            tf.zeros((channels, filters[i + 1], 1), dtype=self.dtype),
+            name="factor_{}".format(i))
+        self._factors.append(factor)
+
+  def _batch_shape_tensor(self):
+    return tf.constant(self._batch_shape_tuple, dtype=int)
+
+  def _batch_shape(self):
+    return tf.TensorShape(self._batch_shape_tuple)
+
+  def _event_shape_tensor(self):
+    return tf.constant((), dtype=int)
+
+  def _event_shape(self):
+    return tf.TensorShape(())
+
+  def _logits_cumulative(self, inputs):
+    """Evaluate logits of the cumulative densities.
+
+    Arguments:
+      inputs: The values at which to evaluate the cumulative densities, expected
+        to be a `tf.Tensor` of shape `(channels, 1, batch)`.
+
+    Returns:
+      A `tf.Tensor` of the same shape as `inputs`, containing the logits of the
+      cumulative densities evaluated at the given inputs.
+    """
+    logits = inputs
+    for i in range(len(self.num_filters) + 1):
+      matrix = tf.nn.softplus(self._matrices[i])
+      logits = tf.linalg.matmul(matrix, logits)
+      logits += self._biases[i]
+      if i < len(self.num_filters):
+        factor = tf.math.tanh(self._factors[i])
+        logits += factor * tf.math.tanh(logits)
+    return logits
+
+  def _prob(self, y):
+    """Called by the base class to compute likelihoods."""
+    # Convert to (channels, 1, batch) format by collapsing dimensions and then
+    # commuting channels to front.
+    y = tf.broadcast_to(
+        y, tf.broadcast_dynamic_shape(tf.shape(y), self.batch_shape_tensor()))
+    shape = tf.shape(y)
+    y = tf.reshape(y, (-1, 1, self.batch_shape.num_elements()))
+    y = tf.transpose(y, (2, 1, 0))
+
+    # Evaluate densities.
+    # We can use the special rule below to only compute differences in the left
+    # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1
+    # for large x, 0 for small x. Subtracting two numbers close to 0 can be done
+    # with much higher precision than subtracting two numbers close to 1.
+    lower = self._logits_cumulative(y - .5)
+    upper = self._logits_cumulative(y + .5)
+    # Flip signs if we can move more towards the left tail of the sigmoid.
+    sign = tf.stop_gradient(-tf.math.sign(lower + upper))
+    p = abs(tf.sigmoid(sign * upper) - tf.sigmoid(sign * lower))
+
+    # Convert back to (broadcasted) input tensor shape.
+    p = tf.transpose(p, (2, 1, 0))
+    p = tf.reshape(p, shape)
+    return p
+
+  def _quantization_offset(self):
+    return tf.constant(0, dtype=self.dtype)
+
+  def _lower_tail(self, tail_mass):
+    tail = helpers.estimate_tail(
+        self._logits_cumulative, -tf.math.log(2 / tail_mass - 1),
+        [self.batch_shape.num_elements(), 1, 1], self.dtype)
+    return tf.reshape(tail, self.batch_shape_tensor())
+
+  def _upper_tail(self, tail_mass):
+    tail = helpers.estimate_tail(
+        self._logits_cumulative, tf.math.log(2 / tail_mass - 1),
+        [self.batch_shape.num_elements(), 1, 1], self.dtype)
+    return tf.reshape(tail, self.batch_shape_tensor())
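
A short usage sketch of the class above (names taken from this file; the bit-count expression is a standard rate estimate under uniform noise, not part of the commit):

    import tensorflow.compat.v2 as tf
    from tensorflow_compression.python.distributions import deep_factorized

    # One independent density per channel; inputs broadcast against batch_shape.
    df = deep_factorized.DeepFactorized(batch_shape=(8,))
    x = tf.random.normal((16, 8))

    # The model already includes the unit-width uniform convolution, so
    # -log2 df.prob(x) is a differentiable estimate of coding cost in bits.
    bits = tf.reduce_sum(-tf.math.log(df.prob(x)) / tf.math.log(2.))
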
tensorflow_compression/python/distributions/deep_factorized_test.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# Copyright 2020 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests of deep factorized distribution."""
+
+import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
+
+from tensorflow_compression.python.distributions import deep_factorized
+from tensorflow_compression.python.distributions import helpers
+
+
+class DeepFactorizedTest(tf.test.TestCase):
+
+  def test_can_instantiate_scalar(self):
+    df = deep_factorized.DeepFactorized()
+    self.assertEqual(df.batch_shape, ())
+    self.assertEqual(df.event_shape, ())
+    self.assertEqual(df.num_filters, (3, 3))
+    self.assertEqual(df.init_scale, 10)
+
+  def test_can_instantiate_batched(self):
+    df = deep_factorized.DeepFactorized(batch_shape=(4, 3))
+    self.assertEqual(df.batch_shape, (4, 3))
+    self.assertEqual(df.event_shape, ())
+    self.assertEqual(df.num_filters, (3, 3))
+    self.assertEqual(df.init_scale, 10)
+
+  def test_variables_receive_gradients(self):
+    df = deep_factorized.DeepFactorized()
+    with tf.GradientTape() as tape:
+      x = tf.random.normal([20])
+      loss = -tf.reduce_mean(df.log_prob(x))
+    grads = tape.gradient(loss, df.trainable_variables)
+    self.assertLen(grads, 8)
+    self.assertNotIn(None, grads)
+
+  def test_logistic_is_special_case(self):
+    # With no hidden units, the density should collapse to a logistic
+    # distribution convolved with a standard uniform distribution.
+    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
+    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
+    x = tf.linspace(-5., 5., 20)
+    prob_df = df.prob(x)
+    prob_log = logistic.cdf(x + .5) - logistic.cdf(x - .5)
+    self.assertAllClose(prob_df, prob_log)
+
+  def test_uniform_is_special_case(self):
+    # With the scale parameter going to zero, the density should approach a
+    # unit-width uniform distribution.
+    df = deep_factorized.DeepFactorized(init_scale=1e-3)
+    x = tf.linspace(-1., 1., 10)
+    self.assertAllClose(df.prob(x), [0, 0, 0, 1, 1, 1, 1, 0, 0, 0])
+
+  def test_quantization_offset_is_zero(self):
+    df = deep_factorized.DeepFactorized()
+    self.assertEqual(helpers.quantization_offset(df), 0)
+
+  def test_tails_and_offset_are_in_order(self):
+    df = deep_factorized.DeepFactorized()
+    offset = helpers.quantization_offset(df)
+    lower_tail = helpers.lower_tail(df, 2**-8)
+    upper_tail = helpers.upper_tail(df, 2**-8)
+    self.assertGreater(upper_tail, offset)
+    self.assertGreater(offset, lower_tail)
+
+  def test_stats_throw_error(self):
+    df = deep_factorized.DeepFactorized()
+    with self.assertRaises(NotImplementedError):
+      df.mode()
+    with self.assertRaises(NotImplementedError):
+      df.mean()
+    with self.assertRaises(NotImplementedError):
+      df.quantile(.5)
+    with self.assertRaises(NotImplementedError):
+      df.survival_function(.5)
+    with self.assertRaises(NotImplementedError):
+      df.sample()
+
+
+if __name__ == "__main__":
+  tf.test.main()
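
For reference, here is why `test_logistic_is_special_case` holds (a derivation from the code above, not part of the commit): with `num_filters=()` and `init_scale=1`, `_logits_cumulative` reduces to a single affine map with weight `softplus(matrix)`, and the initializer `log(expm1(1))` makes that weight exactly 1, since softplus(log(expm1(1))) = log(1 + expm1(1)) = log(e) = 1. The cumulative is then sigmoid(y + bias), the CDF of a logistic distribution with loc = -bias and scale = 1, and `_prob` evaluates its CDF difference over a unit interval:

    import tensorflow.compat.v2 as tf

    # The single-layer weight initializes to exactly 1:
    w = tf.nn.softplus(tf.math.log(tf.math.expm1(1.)))
    print(float(w))  # 1.0 (up to float rounding)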
