
Commit 424ef03

Johannes Ballé authored and copybara-github committed
Fixes for DeepFactorized.
- Implements forgotten survival function.
- Fixes incorrect prob() and log_prob(): we need to broadcast the inputs before calling tape.gradient(), or else the return value will be a reduced sum over the dimensions that happened to be broadcast.
- Removes an unused attribute and makes the code a little more readable.
- Adds custom lower_tail and upper_tail methods that are better behaved.

PiperOrigin-RevId: 340299486
Change-Id: I9c9287eb3a3eaca021bb85b0930377c84a7baf5b
1 parent eeb9a0f commit 424ef03

3 files changed: +62 additions, -64 deletions

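For context on the prob()/log_prob() fix: when a tensor is broadcast against the batch shape inside a tf.GradientTape, tape.gradient() returns the derivative summed over the broadcast dimensions rather than an elementwise derivative, which is why the inputs are now broadcast before the tape is opened. A minimal sketch of the failure mode (illustrative only, not part of the commit; x and loc are made-up stand-ins for the inputs and the batch dimensions):

import tensorflow.compat.v2 as tf

x = tf.constant([0.5])           # shape (1,); stands in for the inputs
loc = tf.constant([0., 1., 2.])  # stands in for a (3,) batch of distributions

with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.sigmoid(x - loc)        # x is broadcast to shape (3,) inside the tape
print(tape.gradient(y, x))       # shape (1,): the three derivatives, summed

x_b = tf.broadcast_to(x, tf.shape(loc))  # broadcast *before* recording instead
with tf.GradientTape() as tape:
  tape.watch(x_b)
  y = tf.sigmoid(x_b - loc)
print(tape.gradient(y, x_b))     # shape (3,): one derivative per element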

tensorflow_compression/python/distributions/BUILD

Lines changed: 4 additions & 1 deletion
@@ -20,7 +20,10 @@ py_library(
     name = "deep_factorized",
     srcs = ["deep_factorized.py"],
     srcs_version = "PY3",
-    deps = [":uniform_noise"],
+    deps = [
+        ":helpers",
+        ":uniform_noise",
+    ],
 )

 py_test(

tensorflow_compression/python/distributions/deep_factorized.py

Lines changed: 34 additions & 13 deletions
@@ -17,6 +17,7 @@
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp

+from tensorflow_compression.python.distributions import helpers
 from tensorflow_compression.python.distributions import uniform_noise


@@ -89,7 +90,6 @@ def __init__(self,
     self._batch_shape_tuple = tuple(int(s) for s in batch_shape)
     self._num_filters = tuple(int(f) for f in num_filters)
     self._init_scale = float(init_scale)
-    self._estimated_tail_mass = None
     super().__init__(
         dtype=dtype,
         reparameterization_type=tfp.distributions.NOT_REPARAMETERIZED,
@@ -130,10 +130,8 @@ def matrix_initializer(i=i):
       self._matrices.append(matrix)

       def bias_initializer(i=i):
-        return tf.random.uniform((channels, filters[i + 1], 1),
-                                 -.5,
-                                 .5,
-                                 dtype=self.dtype)
+        return tf.random.uniform(
+            (channels, filters[i + 1], 1), -.5, .5, dtype=self.dtype)

       bias = tf.Variable(bias_initializer, name="bias_{}".format(i))
       self._biases.append(bias)
@@ -158,6 +156,11 @@ def _event_shape_tensor(self):
   def _event_shape(self):
     return tf.TensorShape(())

+  def _broadcast_inputs(self, inputs):
+    shape = tf.broadcast_dynamic_shape(
+        tf.shape(inputs), self.batch_shape_tensor())
+    return tf.broadcast_to(inputs, shape)
+
   def _logits_cumulative(self, inputs):
     """Evaluate logits of the cumulative densities.

@@ -170,9 +173,6 @@ def _logits_cumulative(self, inputs):
     """
     # Convert to (channels, 1, batch) format by collapsing dimensions and then
     # commuting channels to front.
-    inputs = tf.broadcast_to(
-        inputs,
-        tf.broadcast_dynamic_shape(tf.shape(inputs), self.batch_shape_tensor()))
     shape = tf.shape(inputs)
     inputs = tf.reshape(inputs, (-1, 1, self.batch_shape.num_elements()))
     inputs = tf.transpose(inputs, (2, 1, 0))
@@ -191,35 +191,46 @@ def _logits_cumulative(self, inputs):
     return logits

   def _log_cdf(self, inputs):
+    inputs = self._broadcast_inputs(inputs)
     logits = self._logits_cumulative(inputs)
     return tf.math.log_sigmoid(logits)

   def _log_survival_function(self, inputs):
+    inputs = self._broadcast_inputs(inputs)
     logits = self._logits_cumulative(inputs)
     # 1-sigmoid(x) = sigmoid(-x)
     return tf.math.log_sigmoid(-logits)

   def _cdf(self, inputs):
+    inputs = self._broadcast_inputs(inputs)
     logits = self._logits_cumulative(inputs)
     return tf.math.sigmoid(logits)

+  def _survival_function(self, inputs):
+    inputs = self._broadcast_inputs(inputs)
+    logits = self._logits_cumulative(inputs)
+    # 1-sigmoid(x) = sigmoid(-x)
+    return tf.math.sigmoid(-logits)
+
   def _prob(self, inputs):
-    with tf.GradientTape() as tape:
+    inputs = self._broadcast_inputs(inputs)
+    with tf.GradientTape(watch_accessed_variables=False) as tape:
       tape.watch(inputs)
       cdf = self._cdf(inputs)
     prob = tape.gradient(cdf, inputs)
     return prob

   def _log_prob(self, inputs):
-    # let x=inputs and s(x)=sigmoid(x).
-    with tf.GradientTape() as tape:
+    inputs = self._broadcast_inputs(inputs)
+    with tf.GradientTape(watch_accessed_variables=False) as tape:
       tape.watch(inputs)
       logits = self._logits_cumulative(inputs)
-    # We have F(x) = s(logits(x))
+    # Let x=inputs and s(x)=sigmoid(x).
+    # We have F(x) = s(logits(x)),
     # so p(x) = F'(x)
     #         = s'(logits(x)) * logits'(x)
     #         = s(logits(x))*s(-logits(x)) * logits'(x)
-    # so log p(x) = log(s(logits(x)) + log(s(-logits(x)) + log(logits'(x))
+    # so log p(x) = log(s(logits(x)) + log(s(-logits(x)) + log(logits'(x)).
     log_s_logits = tf.math.log_sigmoid(logits)
     log_s_neg_logits = tf.math.log_sigmoid(-logits)
     dlogits = tape.gradient(logits, inputs)
@@ -228,6 +239,16 @@ def _log_prob(self, inputs):
   def _quantization_offset(self):
     return tf.constant(0, dtype=self.dtype)

+  def _lower_tail(self, tail_mass):
+    logits = tf.math.log(tail_mass / 2 / (1. - tail_mass / 2))
+    return helpers.estimate_tails(
+        self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)
+
+  def _upper_tail(self, tail_mass):
+    logits = -tf.math.log(tail_mass / 2 / (1. - tail_mass / 2))
+    return helpers.estimate_tails(
+        self._logits_cumulative, logits, self.batch_shape_tensor(), self.dtype)
+

 class NoisyDeepFactorized(uniform_noise.UniformNoiseAdapter):
   """DeepFactorized that is convolved with uniform noise."""

tensorflow_compression/python/distributions/deep_factorized_test.py

Lines changed: 24 additions & 50 deletions
@@ -14,14 +14,15 @@
 # ==============================================================================
 """Tests of deep factorized distribution."""

+from absl.testing import parameterized
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp

 from tensorflow_compression.python.distributions import deep_factorized
 from tensorflow_compression.python.distributions import helpers


-class DeepFactorizedTest(tf.test.TestCase):
+class DeepFactorizedTest(tf.test.TestCase, parameterized.TestCase):

   def test_can_instantiate_scalar(self):
     df = deep_factorized.DeepFactorized()
@@ -37,56 +38,31 @@ def test_can_instantiate_batched(self):
     self.assertEqual(df.num_filters, (3, 3))
     self.assertEqual(df.init_scale, 10)

-  def test_logistic_is_special_case_prob(self):
+  @parameterized.parameters(
+      "prob", "log_prob",
+      "cdf", "log_cdf",
+      "survival_function", "log_survival_function",
+  )
+  def test_logistic_is_special_case(self, method):
     # With no hidden units, the density should collapse to a logistic
     # distribution.
     df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
     logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
     x = tf.linspace(-5., 5., 20)
-    prob_df = df.prob(x)
-    prob_logistic = logistic.prob(x)
-    self.assertAllClose(prob_df, prob_logistic)
-
-  def test_logistic_is_special_case_cdf(self):
-    # With no hidden units, the density should collapse to a logistic
-    # distribution.
-    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
-    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
-    x = tf.linspace(-5., 5., 20)
-    cdf_df = df.cdf(x)
-    cdf_logistic = logistic.cdf(x)
-    self.assertAllClose(cdf_df, cdf_logistic)
-
-  def test_logistic_is_special_case_log_prob(self):
-    # With no hidden units, the density should collapse to a logistic
-    # distribution.
-    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
-    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
-    x = tf.linspace(-5000., 5000., 1000)
-    log_prob_df = df.log_prob(x)
-    log_prob_logistic = logistic.log_prob(x)
-    self.assertAllClose(log_prob_df, log_prob_logistic)
-
-  def test_logistic_is_special_case_log_cdf(self):
-    # With no hidden units, the density should collapse to a logistic
-    # distribution.
-    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
-    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
-    x = tf.linspace(-5000., 5000., 1000)
-    log_cdf_df = df.log_cdf(x)
-    log_cdf_logistic = logistic.log_cdf(x)
-    self.assertAllClose(log_cdf_df, log_cdf_logistic)
-
-  def test_logistic_is_special_case_log_survival_function(self):
-    # With no hidden units, the density should collapse to a logistic
-    # distribution.
-    df = deep_factorized.DeepFactorized(num_filters=(), init_scale=1)
-    logistic = tfp.distributions.Logistic(loc=-df._biases[0][0, 0], scale=1.)
-    x = tf.linspace(-5000., 5000., 1000)
-    log_survival_function_df = df.log_survival_function(x)
-    log_survival_function_logistic = logistic.log_survival_function(x)
-    self.assertAllClose(log_survival_function_df,
-                        log_survival_function_logistic)
+    val_df = getattr(df, method)(x)
+    val_logistic = getattr(logistic, method)(x)
+    self.assertAllClose(val_df, val_logistic)
+
+  @parameterized.parameters(
+      "prob", "log_prob",
+      "cdf", "log_cdf",
+      "survival_function", "log_survival_function",
+  )
+  def test_broadcasts_correctly(self, method):
+    df = deep_factorized.DeepFactorized(batch_shape=(2, 3))
+    x = tf.reshape(tf.linspace(-5., 5., 20), (4, 5, 1, 1))
+    val = getattr(df, method)(x)
+    self.assertEqual(val.shape, (4, 5, 2, 3))


 class NoisyDeepFactorizedTest(tf.test.TestCase):
@@ -140,13 +116,11 @@ def test_quantization_offset_is_zero(self):
     df = deep_factorized.NoisyDeepFactorized()
     self.assertEqual(helpers.quantization_offset(df), 0)

-  def test_tails_and_offset_are_in_order(self):
+  def test_tails_are_in_order(self):
     df = deep_factorized.NoisyDeepFactorized()
-    offset = helpers.quantization_offset(df)
     lower_tail = helpers.lower_tail(df, 2**-8)
     upper_tail = helpers.upper_tail(df, 2**-8)
-    self.assertGreater(upper_tail, offset)
-    self.assertGreater(offset, lower_tail)
+    self.assertGreater(upper_tail, lower_tail)

   def test_stats_throw_error(self):
     df = deep_factorized.NoisyDeepFactorized()
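Below is a small usage sketch of the newly implemented survival function (assumes the tensorflow_compression package is installed; not part of the commit). Since survival_function(x) = 1 - cdf(x), the two should agree up to floating-point error:

import tensorflow.compat.v2 as tf
from tensorflow_compression.python.distributions import deep_factorized

df = deep_factorized.DeepFactorized()
x = tf.linspace(-5., 5., 20)
print(df.survival_function(x).numpy())
print((1. - df.cdf(x)).numpy())  # should match the line above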
