
Commit fa8ac89

relational authored and copybara-github committed
Split DeepFactorized into DeepFactorized + NoisyDeepFactorized.
PiperOrigin-RevId: 332792359 Change-Id: I5141a48eae32908c2793fda70b87811e01d1324a
1 parent 7654443 commit fa8ac89
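
For readers skimming the diff, here is a minimal usage sketch of the split, in Python. It is based only on the constructor arguments and attributes exercised in the tests changed below (batch_shape, num_filters, init_scale, .base, prob/cdf/log_prob); the shapes and values are illustrative and not part of the commit.

import tensorflow.compat.v2 as tf
from tensorflow_compression.python.distributions import deep_factorized

# DeepFactorized is now a plain continuous density defined through its learned
# CDF; the unit-width uniform-noise convolution is no longer built into it.
df = deep_factorized.DeepFactorized(batch_shape=(4, 3), num_filters=(3, 3))
x = tf.random.normal((10, 4, 3))
df.log_prob(x)  # density obtained by differentiating the learned CDF
df.cdf(x)       # cdf / log_cdf / log_survival_function are newly exposed

# NoisyDeepFactorized restores the previous behavior: the same density
# convolved with unit-width uniform noise via uniform_noise.UniformNoiseAdapter.
ndf = deep_factorized.NoisyDeepFactorized(batch_shape=(4, 3))
ndf.prob(x)           # equals base.cdf(x + .5) - base.cdf(x - .5)
ndf.base.num_filters  # the wrapped DeepFactorized is reachable as .base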

File tree: 4 files changed (+184, -68 lines)

tensorflow_compression/python/distributions/BUILD

Lines changed: 1 addition & 4 deletions
@@ -19,10 +19,7 @@ py_library(
     name = "deep_factorized",
     srcs = ["deep_factorized.py"],
     srcs_version = "PY3",
-    deps = [
-        ":helpers",
-        "//tensorflow_compression/python/ops:math_ops",
-    ],
+    deps = [":uniform_noise"],
 )
 
 py_test(

tensorflow_compression/python/distributions/deep_factorized.py

Lines changed: 99 additions & 57 deletions
@@ -17,11 +17,31 @@
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
 
-from tensorflow_compression.python.distributions import helpers
-from tensorflow_compression.python.ops import math_ops
+from tensorflow_compression.python.distributions import uniform_noise
 
 
-__all__ = ["DeepFactorized"]
+__all__ = ["DeepFactorized", "NoisyDeepFactorized"]
+
+
+def log_expm1(x):
+  """Computes log(exp(x)-1) stably.
+
+  For large values of x, exp(x) will return Inf whereas log(exp(x)-1) ~= x.
+  Here we use this approximation for x>15, such that the output is non-Inf for
+  all positive values x.
+
+  Args:
+    x: A tensor.
+
+  Returns:
+    log(exp(x)-1)
+
+  """
+  # If x<15.0, we can compute it directly. For larger values,
+  # we have log(exp(x)-1) ~= log(exp(x)) = x.
+  cond = (x < 15.0)
+  x_small = tf.minimum(x, 15.0)
+  return tf.where(cond, tf.math.log(tf.math.expm1(x_small)), x)
 
 
 class DeepFactorized(tfp.distributions.Distribution):
@@ -34,7 +54,7 @@ class DeepFactorized(tfp.distributions.Distribution):
   > J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston<br />
   > https://openreview.net/forum?id=rkcQFMZRb
 
-  This implementation already includes convolution with a unit-width uniform
+  but *without* convolution with a unit-width uniform
   density, as described in appendix 6.2 of the same paper. Please cite the paper
   if you use this code for scientific work.
 
@@ -43,7 +63,8 @@ class DeepFactorized(tfp.distributions.Distribution):
     trainable distribution parameters.
   """
 
-  def __init__(self, batch_shape=(), num_filters=(3, 3), init_scale=10,
+  def __init__(self,
+               batch_shape=(), num_filters=(3, 3), init_scale=10,
                allow_nan_stats=False, dtype=tf.float32, name="DeepFactorized"):
     """Initializer.
 
@@ -98,22 +119,31 @@ def _make_variables(self):
     self._factors = []
 
     for i in range(len(self.num_filters) + 1):
-      init = tf.math.log(tf.math.expm1(1 / scale / filters[i + 1]))
-      init = tf.cast(init, dtype=self.dtype)
-      init = tf.broadcast_to(init, (channels, filters[i + 1], filters[i]))
-      matrix = tf.Variable(init, name="matrix_{}".format(i))
+
+      def matrix_initializer(i=i):
+        init = log_expm1(1 / scale / filters[i + 1])
+        init = tf.cast(init, dtype=self.dtype)
+        init = tf.broadcast_to(init, (channels, filters[i + 1], filters[i]))
+        return init
+
+      matrix = tf.Variable(matrix_initializer, name="matrix_{}".format(i))
       self._matrices.append(matrix)
 
-      bias = tf.Variable(
-          tf.random.uniform(
-              (channels, filters[i + 1], 1), -.5, .5, dtype=self.dtype),
-          name="bias_{}".format(i))
+      def bias_initializer(i=i):
+        return tf.random.uniform((channels, filters[i + 1], 1),
+                                 -.5,
+                                 .5,
+                                 dtype=self.dtype)
+
+      bias = tf.Variable(bias_initializer, name="bias_{}".format(i))
      self._biases.append(bias)
 
       if i < len(self.num_filters):
-        factor = tf.Variable(
-            tf.zeros((channels, filters[i + 1], 1), dtype=self.dtype),
-            name="factor_{}".format(i))
+
+        def factor_initializer(i=i):
+          return tf.zeros((channels, filters[i + 1], 1), dtype=self.dtype)
+
+        factor = tf.Variable(factor_initializer, name="factor_{}".format(i))
         self._factors.append(factor)
 
   def _batch_shape_tensor(self):
@@ -132,13 +162,20 @@ def _logits_cumulative(self, inputs):
     """Evaluate logits of the cumulative densities.
 
     Arguments:
-      inputs: The values at which to evaluate the cumulative densities, expected
-        to be a `tf.Tensor` of shape `(channels, 1, batch)`.
+      inputs: The values at which to evaluate the cumulative densities.
 
     Returns:
       A `tf.Tensor` of the same shape as `inputs`, containing the logits of the
       cumulative densities evaluated at the given inputs.
     """
+    # Convert to (channels, 1, batch) format by collapsing dimensions and then
+    # commuting channels to front.
+    inputs = tf.broadcast_to(
+        inputs,
+        tf.broadcast_dynamic_shape(tf.shape(inputs), self.batch_shape_tensor()))
+    shape = tf.shape(inputs)
+    inputs = tf.reshape(inputs, (-1, 1, self.batch_shape.num_elements()))
+    inputs = tf.transpose(inputs, (2, 1, 0))
     logits = inputs
     for i in range(len(self.num_filters) + 1):
       matrix = tf.nn.softplus(self._matrices[i])
@@ -147,48 +184,53 @@ def _logits_cumulative(self, inputs):
       if i < len(self.num_filters):
         factor = tf.math.tanh(self._factors[i])
         logits += factor * tf.math.tanh(logits)
-    return logits
-
-  def _prob(self, y):
-    """Called by the base class to compute likelihoods."""
-    # Convert to (channels, 1, batch) format by collapsing dimensions and then
-    # commuting channels to front.
-    y = tf.broadcast_to(
-        y, tf.broadcast_dynamic_shape(tf.shape(y), self.batch_shape_tensor()))
-    shape = tf.shape(y)
-    y = tf.reshape(y, (-1, 1, self.batch_shape.num_elements()))
-    y = tf.transpose(y, (2, 1, 0))
-
-    # Evaluate densities.
-    # We can use the special rule below to only compute differences in the left
-    # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1
-    # for large x, 0 for small x. Subtracting two numbers close to 0 can be done
-    # with much higher precision than subtracting two numbers close to 1.
-    lower = self._logits_cumulative(y - .5)
-    upper = self._logits_cumulative(y + .5)
-    # Flip signs if we can move more towards the left tail of the sigmoid.
-    sign = tf.stop_gradient(-tf.math.sign(lower + upper))
-    p = abs(tf.sigmoid(sign * upper) - tf.sigmoid(sign * lower))
-    p = math_ops.lower_bound(p, 0.)
 
     # Convert back to (broadcasted) input tensor shape.
-    p = tf.transpose(p, (2, 1, 0))
-    p = tf.reshape(p, shape)
-    return p
+    logits = tf.transpose(logits, (2, 1, 0))
+    logits = tf.reshape(logits, shape)
+    return logits
+
+  def _log_cdf(self, inputs):
+    logits = self._logits_cumulative(inputs)
+    return tf.math.log_sigmoid(logits)
+
+  def _log_survival_function(self, inputs):
+    logits = self._logits_cumulative(inputs)
+    # 1-sigmoid(x) = sigmoid(-x)
+    return tf.math.log_sigmoid(-logits)
+
+  def _cdf(self, inputs):
+    logits = self._logits_cumulative(inputs)
+    return tf.math.sigmoid(logits)
+
+  def _prob(self, inputs):
+    with tf.GradientTape() as tape:
+      tape.watch(inputs)
+      cdf = self._cdf(inputs)
+    prob = tape.gradient(cdf, inputs)
+    return prob
+
+  def _log_prob(self, inputs):
+    # let x=inputs and s(x)=sigmoid(x).
+    with tf.GradientTape() as tape:
+      tape.watch(inputs)
+      logits = self._logits_cumulative(inputs)
+    # We have F(x) = s(logits(x))
+    # so p(x) = F'(x)
+    #         = s'(logits(x)) * logits'(x)
+    #         = s(logits(x))*s(-logits(x)) * logits'(x)
+    # so log p(x) = log(s(logits(x)) + log(s(-logits(x)) + log(logits'(x))
+    log_s_logits = tf.math.log_sigmoid(logits)
+    log_s_neg_logits = tf.math.log_sigmoid(-logits)
+    dlogits = tape.gradient(logits, inputs)
+    return log_s_logits + log_s_neg_logits + tf.math.log(dlogits)
 
   def _quantization_offset(self):
     return tf.constant(0, dtype=self.dtype)
 
-  def _lower_tail(self, tail_mass):
-    tail = helpers.estimate_tails(
-        self._logits_cumulative, -tf.math.log(2 / tail_mass - 1),
-        tf.constant([self.batch_shape.num_elements(), 1, 1], tf.int32),
-        self.dtype)
-    return tf.reshape(tail, self.batch_shape_tensor())
-
-  def _upper_tail(self, tail_mass):
-    tail = helpers.estimate_tails(
-        self._logits_cumulative, tf.math.log(2 / tail_mass - 1),
-        tf.constant([self.batch_shape.num_elements(), 1, 1], tf.int32),
-        self.dtype)
-    return tf.reshape(tail, self.batch_shape_tensor())
+
+
+class NoisyDeepFactorized(uniform_noise.UniformNoiseAdapter):
+  """DeepFactorized that is convolved with uniform noise."""
+
+  def __init__(self, name="NoisyDeepFactorized", **kwargs):
+    super().__init__(DeepFactorized(**kwargs), name=name)
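A quick numerical note on the log_expm1 helper added above (a sketch, not part of the commit): near the cutoff x = 15 the exact value and the approximation log(exp(x) - 1) ~= x already agree to better than float32 resolution, while the naive formula overflows for large x. The tf.minimum clamp presumably keeps the unselected tf.where branch finite, since tf.where evaluates both branches.

import numpy as np

x = 15.0
print(x - np.log(np.expm1(x)))   # ~3.1e-7, below the float32 spacing near 15 (~9.5e-7)

big = np.float32(100.0)
print(np.log(np.expm1(big)))     # inf: expm1 overflows float32, the direct formula fails
# log_expm1 would return 100.0 here, since x >= 15 selects the identity branch.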
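The comment block in _log_prob above can be spelled out as a worked equation. With s the logistic sigmoid, l(x) the output of _logits_cumulative, and the identity s'(y) = s(y) * s(-y):

  F(x) = s(l(x))
  p(x) = F'(x) = s'(l(x)) * l'(x) = s(l(x)) * s(-l(x)) * l'(x)
  log p(x) = log s(l(x)) + log s(-l(x)) + log l'(x)

which is exactly the sum of log_s_logits, log_s_neg_logits and tf.math.log(dlogits) returned by _log_prob, with l'(x) obtained from the GradientTape.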

tensorflow_compression/python/distributions/deep_factorized_test.py

Lines changed: 80 additions & 7 deletions
@@ -37,8 +37,80 @@ def test_can_instantiate_batched(self):
     self.assertEqual(df.num_filters, (3, 3))
     self.assertEqual(df.init_scale, 10)
 
+  def test_logistic_is_special_case_prob(self):
+    # With no hidden units, the density should collapse to a logistic
+    # distribution.
+    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
+    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
+    x = tf.linspace(-5., 5., 20)
+    prob_df = df.prob(x)
+    prob_logistic = logistic.prob(x)
+    self.assertAllClose(prob_df, prob_logistic)
+
+  def test_logistic_is_special_case_cdf(self):
+    # With no hidden units, the density should collapse to a logistic
+    # distribution.
+    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
+    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
+    x = tf.linspace(-5., 5., 20)
+    cdf_df = df.cdf(x)
+    cdf_logistic = logistic.cdf(x)
+    self.assertAllClose(cdf_df, cdf_logistic)
+
+  def test_logistic_is_special_case_log_prob(self):
+    # With no hidden units, the density should collapse to a logistic
+    # distribution.
+    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
+    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
+    x = tf.linspace(-5000., 5000., 1000)
+    log_prob_df = df.log_prob(x)
+    log_prob_logistic = logistic.log_prob(x)
+    self.assertAllClose(log_prob_df, log_prob_logistic)
+
+  def test_logistic_is_special_case_log_cdf(self):
+    # With no hidden units, the density should collapse to a logistic
+    # distribution.
+    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
+    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
+    x = tf.linspace(-5000., 5000., 1000)
+    log_cdf_df = df.log_cdf(x)
+    log_cdf_logistic = logistic.log_cdf(x)
+    self.assertAllClose(log_cdf_df, log_cdf_logistic)
+
+  def test_logistic_is_special_case_log_survival_function(self):
+    # With no hidden units, the density should collapse to a logistic
+    # distribution.
+    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
+    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
+    x = tf.linspace(-5000., 5000., 1000)
+    log_survival_function_df = df.log_survival_function(x)
+    log_survival_function_logistic = logistic.log_survival_function(x)
+    self.assertAllClose(log_survival_function_df,
+                        log_survival_function_logistic)
+
+
+class NoisyDeepFactorizedTest(tf.test.TestCase):
+
+  def test_can_instantiate_and_run_scalar(self):
+    df = deep_factorized.NoisyDeepFactorized(num_filters=(2, 3, 4))
+    self.assertEqual(df.batch_shape, ())
+    self.assertEqual(df.event_shape, ())
+    self.assertEqual(df.base.num_filters, (2, 3, 4))
+    self.assertEqual(df.base.init_scale, 10)
+    x = tf.random.normal((10,))
+    df.prob(x)
+
+  def test_can_instantiate_and_run_batched(self):
+    df = deep_factorized.NoisyDeepFactorized(batch_shape=(4, 3))
+    self.assertEqual(df.batch_shape, (4, 3))
+    self.assertEqual(df.event_shape, ())
+    self.assertEqual(df.base.num_filters, (3, 3))
+    self.assertEqual(df.base.init_scale, 10)
+    x = tf.random.normal((10, 4, 3))
+    df.prob(x)
+
   def test_variables_receive_gradients(self):
-    df = deep_factorized.DeepFactorized()
+    df = deep_factorized.NoisyDeepFactorized()
     with tf.GradientTape() as tape:
       x = tf.random.normal([20])
       loss = -tf.reduce_mean(df.log_prob(x))
@@ -49,8 +121,9 @@ def test_variables_receive_gradients(self):
   def test_logistic_is_special_case(self):
     # With no hidden units, the density should collapse to a logistic
     # distribution convolved with a standard uniform distribution.
-    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
-    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
+    df = deep_factorized.NoisyDeepFactorized(num_filters=(), init_scale=1)
+    logistic = tfp.distributions.Logistic(loc=-df.base._biases[0][0, 0],
+                                          scale=1.)
     x = tf.linspace(-5., 5., 20)
     prob_df = df.prob(x)
     prob_log = logistic.cdf(x + .5) - logistic.cdf(x - .5)
@@ -59,24 +132,24 @@ def test_logistic_is_special_case(self):
   def test_uniform_is_special_case(self):
     # With the scale parameter going to zero, the density should approach a
     # unit-width uniform distribution.
-    df = deep_factorized.DeepFactorized(init_scale=1e-3)
+    df = deep_factorized.NoisyDeepFactorized(init_scale=1e-3)
     x = tf.linspace(-1., 1., 10)
     self.assertAllClose(df.prob(x), [0, 0, 0, 1, 1, 1, 1, 0, 0, 0])
 
   def test_quantization_offset_is_zero(self):
-    df = deep_factorized.DeepFactorized()
+    df = deep_factorized.NoisyDeepFactorized()
     self.assertEqual(helpers.quantization_offset(df), 0)
 
   def test_tails_and_offset_are_in_order(self):
-    df = deep_factorized.DeepFactorized()
+    df = deep_factorized.NoisyDeepFactorized()
     offset = helpers.quantization_offset(df)
     lower_tail = helpers.lower_tail(df, 2**-8)
     upper_tail = helpers.upper_tail(df, 2**-8)
     self.assertGreater(upper_tail, offset)
     self.assertGreater(offset, lower_tail)
 
   def test_stats_throw_error(self):
-    df = deep_factorized.DeepFactorized()
+    df = deep_factorized.NoisyDeepFactorized()
     with self.assertRaises(NotImplementedError):
       df.mode()
     with self.assertRaises(NotImplementedError):
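The updated test_logistic_is_special_case above compares the noisy model against logistic.cdf(x + .5) - logistic.cdf(x - .5). That is the general relation behind the uniform-noise wrapping: convolving a base density with unit-width uniform noise gives

  p_noisy(x) = F(x + 1/2) - F(x - 1/2),

where F is the CDF of the base distribution. With num_filters=() the base DeepFactorized reduces to a logistic distribution, which is what the test exploits.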

tensorflow_compression/python/distributions/helpers_test.py

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,10 @@ def test_deep_factorized_tails_are_in_order(self):
     self.assertAllGreater(
         helpers.upper_tail(dist, 2**-8) - helpers.lower_tail(dist, 2**-8), 0)
 
+  def test_noisy_deep_factorized_tails_are_in_order(self):
+    dist = deep_factorized.NoisyDeepFactorized(batch_shape=[10])
+    self.assertAllGreater(
+        helpers.upper_tail(dist, 2**-8) - helpers.lower_tail(dist, 2**-8), 0)
 
 if __name__ == "__main__":
   tf.test.main()
