
Commit 9054647

Johannes Ballé authored and copybara-github committed
Removes likelihood_bound argument.
With the current set of density models, it seems the numerical fix isn't necessary any more. Moreover, this can cause gradients to stagnate in low probability regions.

PiperOrigin-RevId: 340650635
Change-Id: Ib48b50e9a88526d67574705fa2e9e3a035a3ebb5

1 parent 424ef03 · commit 9054647
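The gradient issue the message refers to can be sketched with a toy example. The snippet below is illustrative only and not part of this commit; it models the old bound as a plain element-wise maximum (the library's math_ops.lower_bound may define a custom gradient, so treat this strictly as the hard-clipping case). Once a likelihood is clipped to the bound, the clipped value no longer depends on the bottleneck, so the rate term contributes zero gradient in low-probability regions, whereas working with log_prob keeps a usable gradient.

import tensorflow as tf
import tensorflow_probability as tfp

# Illustration only: a value far out in the tail of a standard normal prior,
# where the density is many orders of magnitude below the old 1e-9 bound.
prior = tfp.distributions.Normal(loc=0., scale=1.)
x = tf.Variable(12.)

with tf.GradientTape(persistent=True) as tape:
  # Old-style rate term, modeling the bound as a hard element-wise maximum.
  clipped = tf.maximum(prior.prob(x), 1e-9)
  old_bits = -tf.math.log(clipped) / tf.math.log(2.)
  # New-style rate term, computed directly in the log domain.
  new_bits = -prior.log_prob(x) / tf.math.log(2.)

print(tape.gradient(old_bits, x))  # 0.0: the clip blocks the gradient
print(tape.gradient(new_bits, x))  # non-zero: x still receives a rate signal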

6 files changed, 11 insertions(+), 41 deletions(-)

tensorflow_compression/python/entropy_models/BUILD

Lines changed: 0 additions & 1 deletion
@@ -32,7 +32,6 @@ py_library(
     deps = [
         ":continuous_base",
         "//tensorflow_compression/python/distributions:helpers",
-        "//tensorflow_compression/python/ops:math_ops",
         "//tensorflow_compression/python/ops:range_coding_ops",
     ],
 )

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 1 addition & 12 deletions
@@ -38,8 +38,7 @@ class ContinuousEntropyModelBase(tf.Module, metaclass=abc.ABCMeta):
 
   @abc.abstractmethod
   def __init__(self, prior, coding_rank, compression=False,
-               likelihood_bound=1e-9, tail_mass=2**-8,
-               range_coder_precision=12, no_variables=False):
+               tail_mass=2**-8, range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -55,8 +54,6 @@ def __init__(self, prior, coding_rank, compression=False,
       compression: Boolean. If set to `True`, the range coding tables used by
         `compress()` and `decompress()` will be built on instantiation. If set
         to `False`, these two methods will not be accessible.
-      likelihood_bound: Float. Lower bound for likelihood values, to prevent
-        training instabilities.
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
       range_coder_precision: Integer. Precision passed to the range coding op.
@@ -76,7 +73,6 @@ def __init__(self, prior, coding_rank, compression=False,
     self._prior_shape = tuple(int(s) for s in prior.batch_shape)
     self._coding_rank = int(coding_rank)
     self._compression = bool(compression)
-    self._likelihood_bound = float(likelihood_bound)
     self._tail_mass = float(tail_mass)
     self._range_coder_precision = int(range_coder_precision)
     self._no_variables = bool(no_variables)
@@ -143,11 +139,6 @@ def compression(self):
     """Whether this entropy model is prepared for compression."""
     return self._compression
 
-  @property
-  def likelihood_bound(self):
-    """Lower bound for likelihood values."""
-    return self._likelihood_bound
-
   @property
   def tail_mass(self):
     """Approximate probability mass which is range encoded with overflow."""
@@ -285,7 +276,6 @@ def get_config(self):
         dtype=self._dtype.name,
         prior_shape=self._prior_shape,
         coding_rank=self._coding_rank,
-        likelihood_bound=self._likelihood_bound,
         tail_mass=self._tail_mass,
         range_coder_precision=self._range_coder_precision,
         cdf_width=self._cdf.shape.as_list()[1],
@@ -314,7 +304,6 @@ def from_config(cls, config):
     self._prior_shape = tuple(int(s) for s in config["prior_shape"])
     self._coding_rank = int(config["coding_rank"])
     self._compression = True
-    self._likelihood_bound = float(config["likelihood_bound"])
     self._tail_mass = float(config["tail_mass"])
     self._range_coder_precision = int(config["range_coder_precision"])
     self._no_variables = False
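One consequence of dropping the key from get_config() and from_config(): configs serialized by an older version of the library still carry a likelihood_bound entry, and whether the new from_config() silently ignores unknown keys is not visible in this diff. A hypothetical caller-side shim (not part of this commit; saved_config is a made-up name for a dict produced by the old get_config()):

config = dict(saved_config)           # copy so the stored config stays intact
config.pop("likelihood_bound", None)  # drop the key this commit removed
em = ContinuousBatchedEntropyModel.from_config(config)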

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 4 additions & 10 deletions
@@ -20,7 +20,6 @@
 
 from tensorflow_compression.python.distributions import helpers
 from tensorflow_compression.python.entropy_models import continuous_base
-from tensorflow_compression.python.ops import math_ops
 from tensorflow_compression.python.ops import range_coding_ops
 
 
@@ -74,8 +73,7 @@ class ContinuousBatchedEntropyModel(continuous_base.ContinuousEntropyModelBase):
   """
 
   def __init__(self, prior, coding_rank, compression=False,
-               likelihood_bound=1e-9, tail_mass=2**-8,
-               range_coder_precision=12, no_variables=False):
+               tail_mass=2**-8, range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -93,8 +91,6 @@ def __init__(self, prior, coding_rank, compression=False,
       compression: Boolean. If set to `True`, the range coding tables used by
         `compress()` and `decompress()` will be built on instantiation. If set
         to `False`, these two methods will not be accessible.
-      likelihood_bound: Float. Lower bound for likelihood values, to prevent
-        training instabilities.
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
       range_coder_precision: Integer. Precision passed to the range coding op.
@@ -112,7 +108,6 @@ def __init__(self, prior, coding_rank, compression=False,
         prior=prior,
         coding_rank=coding_rank,
         compression=compression,
-        likelihood_bound=likelihood_bound,
         tail_mass=tail_mass,
         range_coder_precision=range_coder_precision,
         no_variables=no_variables,
@@ -173,11 +168,10 @@ def bits(self, bottleneck, training=True):
           tf.shape(bottleneck), minval=-.5, maxval=.5, dtype=bottleneck.dtype)
     else:
       quantized = self.quantize(bottleneck)
-    probs = self.prior.prob(quantized)
-    probs = math_ops.lower_bound(probs, self.likelihood_bound)
+    log_probs = self.prior.log_prob(quantized)
     axes = tuple(range(-self.coding_rank, 0))
-    bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / (
-        -tf.math.log(tf.constant(2., dtype=probs.dtype)))
+    bits = tf.reduce_sum(log_probs, axis=axes) / (
+        -tf.math.log(tf.constant(2, dtype=log_probs.dtype)))
     return bits
 
   @tf.Module.with_name_scope
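The rewritten bits() method now works in the log domain throughout: summing log_prob over the coding dimensions gives the log-likelihood in nats, and dividing by -ln 2 yields the estimated codelength in bits per coding unit. A minimal standalone sketch of the same arithmetic (not this class's actual code path), using a tensorflow_probability normal as a stand-in for the package's noise-relaxed priors and coding_rank = 1:

import tensorflow as tf
import tensorflow_probability as tfp

prior = tfp.distributions.Normal(loc=0., scale=1.)      # stand-in prior
quantized = tf.constant([[0., 1., -2.], [3., 0., 1.]])   # two coding units

coding_rank = 1
log_probs = prior.log_prob(quantized)            # log-likelihoods in nats
axes = tuple(range(-coding_rank, 0))             # sum over coding dimensions
bits = tf.reduce_sum(log_probs, axis=axes) / (
    -tf.math.log(tf.constant(2., dtype=log_probs.dtype)))
print(bits)  # shape [2]: estimated codelength in bits per coding unit

For any value whose likelihood exceeds the old 1e-9 bound, this is numerically the same as the previous prob, lower_bound, then log path; the two only differ deep in the tails, which is exactly where the bound used to flatten the objective.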

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@ def test_can_instantiate(self):
     em = ContinuousBatchedEntropyModel(noisy, 1)
     self.assertIs(em.prior, noisy)
     self.assertEqual(em.coding_rank, 1)
-    self.assertEqual(em.likelihood_bound, 1e-9)
     self.assertEqual(em.tail_mass, 2**-8)
     self.assertEqual(em.range_coder_precision, 12)
     self.assertEqual(em.dtype, noisy.dtype)

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 6 additions & 14 deletions
@@ -128,8 +128,7 @@ class ContinuousIndexedEntropyModel(continuous_base.ContinuousEntropyModelBase):
 
   def __init__(self, prior_fn, index_ranges, parameter_fns, coding_rank,
                compression=False, channel_axis=-1, dtype=tf.float32,
-               likelihood_bound=1e-9, tail_mass=2**-8,
-               range_coder_precision=12, no_variables=False):
+               tail_mass=2**-8, range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -165,8 +164,6 @@ def __init__(self, prior_fn, index_ranges, parameter_fns, coding_rank,
        dimension.
       dtype: `tf.dtypes.DType`. The data type of all floating-point
         computations carried out in this class.
-      likelihood_bound: Float. Lower bound for likelihood values, to prevent
-        training instabilities.
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
       range_coder_precision: Integer. Precision passed to the range coding op.
@@ -209,7 +206,6 @@ def __init__(self, prior_fn, index_ranges, parameter_fns, coding_rank,
         prior=prior,
         coding_rank=coding_rank,
         compression=compression,
-        likelihood_bound=likelihood_bound,
         tail_mass=tail_mass,
         range_coder_precision=range_coder_precision,
         no_variables=no_variables,
@@ -285,11 +281,10 @@ def bits(self, bottleneck, indexes, training=True):
     else:
       offset = helpers.quantization_offset(prior)
       quantized = self._quantize(bottleneck, offset)
-    probs = prior.prob(quantized)
-    probs = math_ops.lower_bound(probs, self.likelihood_bound)
+    log_probs = prior.log_prob(quantized)
     axes = tuple(range(-self.coding_rank, 0))
-    bits = tf.reduce_sum(tf.math.log(probs), axis=axes) / (
-        -tf.math.log(tf.constant(2., dtype=probs.dtype)))
+    bits = tf.reduce_sum(log_probs, axis=axes) / (
+        -tf.math.log(tf.constant(2, dtype=log_probs.dtype)))
     return bits
 
   @tf.Module.with_name_scope
@@ -439,8 +434,8 @@ class LocationScaleIndexedEntropyModel(ContinuousIndexedEntropyModel):
   """
 
   def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
-               compression=False, dtype=tf.float32, likelihood_bound=1e-9,
-               tail_mass=2**-8, range_coder_precision=12, no_variables=False):
+               compression=False, dtype=tf.float32, tail_mass=2**-8,
+               range_coder_precision=12, no_variables=False):
     """Initializer.
 
     Arguments:
@@ -466,8 +461,6 @@ def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
        be accessible.
       dtype: `tf.dtypes.DType`. The data type of all floating-point
         computations carried out in this class.
-      likelihood_bound: Float. Lower bound for likelihood values, to prevent
-        training instabilities.
       tail_mass: Float. Approximate probability mass which is range encoded with
         less precision, by using a Golomb-like code.
       range_coder_precision: Integer. Precision passed to the range coding op.
@@ -485,7 +478,6 @@ def __init__(self, prior_fn, num_scales, scale_fn, coding_rank,
         coding_rank=coding_rank,
         compression=compression,
         dtype=dtype,
-        likelihood_bound=likelihood_bound,
         tail_mass=tail_mass,
         range_coder_precision=range_coder_precision,
         no_variables=no_variables,

tensorflow_compression/python/entropy_models/continuous_indexed_test.py

Lines changed: 0 additions & 3 deletions
@@ -31,7 +31,6 @@ def test_can_instantiate_one_dimensional(self):
         dict(loc=lambda _: 0, scale=lambda i: tf.exp(i / 8 - 5)), 1)
     self.assertIsInstance(em.prior, uniform_noise.NoisyNormal)
     self.assertEqual(em.coding_rank, 1)
-    self.assertEqual(em.likelihood_bound, 1e-9)
     self.assertEqual(em.tail_mass, 2**-8)
     self.assertEqual(em.range_coder_precision, 12)
     self.assertEqual(em.dtype, tf.float32)
@@ -50,7 +49,6 @@ def test_can_instantiate_n_dimensional(self):
     self.assertIsInstance(em.prior, uniform_noise.NoisyLogisticMixture)
     self.assertEqual(em.coding_rank, 1)
     self.assertEqual(em.channel_axis, -1)
-    self.assertEqual(em.likelihood_bound, 1e-9)
     self.assertEqual(em.tail_mass, 2**-8)
     self.assertEqual(em.range_coder_precision, 12)
     self.assertEqual(em.dtype, tf.float32)
@@ -63,7 +61,6 @@ def test_can_instantiate(self):
         uniform_noise.NoisyNormal, 64, lambda i: tf.exp(i / 8 - 5), 1)
     self.assertIsInstance(em.prior, uniform_noise.NoisyNormal)
     self.assertEqual(em.coding_rank, 1)
-    self.assertEqual(em.likelihood_bound, 1e-9)
     self.assertEqual(em.tail_mass, 2**-8)
     self.assertEqual(em.range_coder_precision, 12)
     self.assertEqual(em.dtype, tf.float32)
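Since the argument is removed outright rather than deprecated, callers that still pass it will fail with a TypeError (none of the signatures above accept **kwargs). A hypothetical migration for calling code, assuming prior is an already-constructed prior distribution:

# Before this commit:
# em = ContinuousBatchedEntropyModel(
#     prior, coding_rank=1, compression=True, likelihood_bound=1e-9)

# After this commit, simply drop the keyword; there is no replacement knob:
em = ContinuousBatchedEntropyModel(prior, coding_rank=1, compression=True)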
