
Commit f3ca873

relational authored and copybara-github committed
Update (experimental) laplace_tail_mass to use NoisyLaplace instead of Laplace.
This is because in some cases we assume that the distribution is a valid PMF when evaluated at integer-spaced locations.

PiperOrigin-RevId: 460402710
Change-Id: Idac61b6865453d906ee93334437905c4eadb750f
1 parent f93c3e8 commit f3ca873
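The rationale in the commit message is easiest to see numerically. Below is a minimal sketch, assuming a tensorflow_compression build that already includes this commit (the loc/scale values are illustrative): a continuous Laplace density evaluated at integer spacing need not sum to 1, whereas NoisyLaplace convolves the Laplace with uniform noise on [-0.5, 0.5], so each evaluation point carries the probability mass of a unit-wide bin and the values form a valid PMF.

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_compression.python.distributions import uniform_noise

# Integer-spaced evaluation points, offset by an arbitrary constant c = 0.05.
x = 0.05 + tf.range(-100, 100, dtype=tf.float32)

# Plain Laplace: these are density values, not probabilities of anything.
cont = tfp.distributions.Laplace(loc=0.1, scale=0.3)
print(tf.reduce_sum(cont.prob(x)))   # generally not 1.0

# NoisyLaplace: prob(x) equals CDF(x + 0.5) - CDF(x - 0.5), i.e. per-bin mass.
noisy = uniform_noise.NoisyLaplace(loc=0.1, scale=0.3)
print(tf.reduce_sum(noisy.prob(x)))  # ~1.0, a valid PMF on the integer grid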

File tree

7 files changed, +33 -11 lines changed


tensorflow_compression/python/distributions/uniform_noise.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,7 @@
     "NoisyMixtureSameFamily",
     "NoisyNormal",
     "NoisyLogistic",
+    "NoisyLaplace",
     "NoisyNormalMixture",
     "NoisyLogisticMixture",
 ]
@@ -265,6 +266,14 @@ def __init__(self, name="NoisyLogistic", **kwargs):
     super().__init__(tfp.distributions.Logistic(**kwargs), name=name)


+class NoisyLaplace(UniformNoiseAdapter):
+  """Laplacian distribution with additive i.i.d. uniform noise."""
+
+  def __init__(self, name="NoisyLaplace", **kwargs):
+    """Initializer, taking the same arguments as `tfpd.Laplace`."""
+    super().__init__(tfp.distributions.Laplace(**kwargs), name=name)
+
+
 class NoisyNormalMixture(NoisyMixtureSameFamily):
   """Mixture of normal distributions with additive i.i.d. uniform noise."""
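For reference, a hedged usage sketch of the new class, based only on what the hunk above defines (the loc/scale values below are illustrative):

import tensorflow as tf
from tensorflow_compression.python.distributions import uniform_noise

# NoisyLaplace forwards its kwargs to tfp.distributions.Laplace and wraps it in
# UniformNoiseAdapter, which adds i.i.d. uniform noise on [-0.5, 0.5].
prior = uniform_noise.NoisyLaplace(loc=0.0, scale=1.0)

x = tf.constant([0.0, 1.0, 2.0])
print(prior.prob(x))    # mass of the underlying Laplace in each unit-wide bin
print(prior.sample(4))  # Laplace samples perturbed by uniform noise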

tensorflow_compression/python/distributions/uniform_noise_test.py

Lines changed: 12 additions & 0 deletions
@@ -53,6 +53,13 @@ def test_uniform_is_special_case(self):
     x = tf.linspace(mean - 1, mean + 1, 10)
     self.assertAllClose(dist.prob(x), [0, 0, 0, 1, 1, 1, 1, 0, 0, 0])

+  def test_is_pmf_at_integer_grid(self):
+    # When evaluated at integers, P_i = p(c+i) should be a valid PMF.
+    dist = self.dist_cls(loc=0.1, scale=0.3)
+    x = 0.05 + tf.range(-100, 100, dtype=tf.float32)
+    pmf_probs = dist.prob(x)
+    self.assertAllClose(tf.reduce_sum(pmf_probs), 1.0)
+
   def test_sampling_works(self):
     dist = self.dist_cls(loc=0, scale=[3, 5])
     sample = dist.sample((5, 4))
@@ -90,6 +97,11 @@ class NoisyLogisticTest(LocationScaleTest, tf.test.TestCase):
   dist_cls = uniform_noise.NoisyLogistic


+class NoisyLaplaceTest(LocationScaleTest, tf.test.TestCase):
+
+  dist_cls = uniform_noise.NoisyLaplace
+
+
 class MixtureTest:
   """Common tests for noisy mixture distributions."""

tensorflow_compression/python/entropy_models/BUILD

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ py_library(
     srcs = ["continuous_base.py"],
     deps = [
         "//tensorflow_compression/python/distributions:helpers",
+        "//tensorflow_compression/python/distributions:uniform_noise",
         "//tensorflow_compression/python/ops:gen_ops",
     ],
 )

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 4 additions & 4 deletions
@@ -17,8 +17,8 @@
 import abc
 from absl import logging
 import tensorflow as tf
-import tensorflow_probability as tfp
 from tensorflow_compression.python.distributions import helpers
+from tensorflow_compression.python.distributions import uniform_noise
 from tensorflow_compression.python.ops import gen_ops


@@ -69,7 +69,7 @@ def __init__(self,
       bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
         Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
       laplace_tail_mass: Float. If non-zero, will augment the prior with a
-        Laplace mixture for training stability. (experimental)
+        `NoisyLaplace` mixture component for training stability. (experimental)
     """
     super().__init__()
     self._prior = None  # This will be set by subclasses, if appropriate.
@@ -137,7 +137,7 @@ def expected_grads(self):

   @property
   def laplace_tail_mass(self):
-    """Whether to augment the prior with a Laplace mixture."""
+    """Whether to augment the prior with a `NoisyLaplace` mixture."""
     return self._laplace_tail_mass

   @property
@@ -298,7 +298,7 @@ def _log_prob(self, prior, bottleneck_perturbed):
     """Evaluates prior.log_prob(bottleneck + noise)."""
     bottleneck_perturbed = tf.cast(bottleneck_perturbed, prior.dtype)
     if self.laplace_tail_mass:
-      laplace_prior = tfp.distributions.Laplace(
+      laplace_prior = uniform_noise.NoisyLaplace(
           loc=tf.constant(0, dtype=prior.dtype),
           scale=tf.constant(1, dtype=prior.dtype))
       probs = prior.prob(bottleneck_perturbed)
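The hunk above ends just before `probs` is blended with the Laplace term. As a rough sketch of what the experimental `laplace_tail_mass` option does conceptually (not the library's exact code; the helper name and the simple convex mixture below are illustrative), the prior is mixed with a unit-scale `NoisyLaplace` so that log-probabilities stay finite far out in the tails during training:

import tensorflow as tf
from tensorflow_compression.python.distributions import uniform_noise

def mixed_log_prob(prior, x, laplace_tail_mass=1e-3):
  # Illustrative convex mixture of the prior with a unit-scale NoisyLaplace.
  laplace_prior = uniform_noise.NoisyLaplace(
      loc=tf.constant(0, dtype=prior.dtype),
      scale=tf.constant(1, dtype=prior.dtype))
  probs = ((1 - laplace_tail_mass) * prior.prob(x) +
           laplace_tail_mass * laplace_prior.prob(x))
  return tf.math.log(probs)

prior = uniform_noise.NoisyNormal(loc=0.0, scale=1.0)
x = tf.constant([0.0, 5.0, 50.0])
print(mixed_log_prob(prior, x))  # finite even where prior.prob(x) underflows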

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 2 additions & 2 deletions
@@ -171,8 +171,8 @@ def __init__(self,
         use. If provided (not `None`), then `offset_heuristic` is ineffective.
       decode_sanity_check: Boolean. If `True`, an raises an error if the binary
         strings passed into `decompress` are not completely decoded.
-      laplace_tail_mass: Float. If positive, will augment the prior with a
-        Laplace mixture for training stability. (experimental)
+      laplace_tail_mass: Float. If non-zero, will augment the prior with a
+        `NoisyLaplace` mixture component for training stability. (experimental)
     """
     if (prior is None) == (prior_shape is None):
       raise ValueError("Either `prior` or `prior_shape` must be provided.")
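For context, a hedged usage sketch of the documented parameter through the public API (the prior, shapes, and values below are illustrative assumptions, not part of this commit):

import tensorflow as tf
import tensorflow_compression as tfc

# Train-time entropy model with the experimental Laplace tail mass enabled.
prior = tfc.NoisyDeepFactorized(batch_shape=(16,))
em = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False, laplace_tail_mass=1e-3)

y = tf.random.normal((8, 16))       # a batch of bottleneck vectors
y_hat, bits = em(y, training=True)  # rate estimate uses the augmented prior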

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 4 additions & 4 deletions
@@ -189,8 +189,8 @@ def __init__(self,
         computations. Defaults to `tf.float32`.
       decode_sanity_check: Boolean. If `True`, an raises an error if the binary
         strings passed into `decompress` are not completely decoded.
-      laplace_tail_mass: Float. If positive, will augment the prior with a
-        laplace mixture for training stability. (experimental)
+      laplace_tail_mass: Float. If non-zero, will augment the prior with a
+        `NoisyLaplace` mixture component for training stability. (experimental)
     """
     if not callable(prior_fn):
       raise TypeError("`prior_fn` must be a class or factory function.")
@@ -496,8 +496,8 @@ def __init__(self,
         Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
       prior_dtype: `tf.dtypes.DType`. Data type of prior and probability
         computations. Defaults to `tf.float32`.
-      laplace_tail_mass: Float. If positive, will augment the prior with a
-        laplace mixture for training stability. (experimental)
+      laplace_tail_mass: Float. If non-zero, will augment the prior with a
+        `NoisyLaplace` mixture component for training stability. (experimental)
     """
     num_scales = int(num_scales)
     super().__init__(

tensorflow_compression/python/entropy_models/universal_test.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def test_laplace_tail_mass_for_large_inputs(self):
         coding_rank=1,
         compression=True,
         laplace_tail_mass=1e-3)
-    x = tf.convert_to_tensor([1e3, 1e4, 1e5, 1e6, 1e7, 1e8], tf.float32)
+    x = tf.convert_to_tensor([1e3, 1e4, 1e5, 1e6], tf.float32)
     _, bits = em(x[..., None])
     self.assertAllClose(bits, tf.abs(x) / tf.math.log(2.0), rtol=0.01)
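A quick sanity check of the value this test expects: with `laplace_tail_mass` w = 1e-3 and a unit-scale Laplace tail, the mixed prior behaves like p(x) ≈ w · sinh(0.5) · exp(-|x|) for large |x| (the base prior's probability underflows to zero), so the rate is -log2 p(x) ≈ |x| / ln 2 plus a constant of roughly 11 bits, which falls within rtol=0.01 once |x| ≥ 1e3. Dropping 1e7 and 1e8 from the input is presumably needed because in float32 the offsets x ± 0.5 used by `NoisyLaplace`'s bin-mass computation stop being representable distinctly once |x| reaches about 1e7.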
