Skip to content

Commit 01ff785

Browse files
relationalcopybara-github
authored andcommitted
Add a universally quantized entropy model to tensorflow_compression.
PiperOrigin-RevId: 341791361 Change-Id: Ib80a9ea38f43853d84c76985a8eac0d918d705c7
1 parent e0bc5c5 commit 01ff785

17 files changed

+1387
-149
lines changed

tensorflow_compression/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tensorflow_compression.python.distributions.uniform_noise import *
3030
from tensorflow_compression.python.entropy_models.continuous_batched import *
3131
from tensorflow_compression.python.entropy_models.continuous_indexed import *
32+
from tensorflow_compression.python.entropy_models.universal import *
3233
from tensorflow_compression.python.layers.entropy_models import *
3334
from tensorflow_compression.python.layers.gdn import *
3435
from tensorflow_compression.python.layers.initializers import *

tensorflow_compression/python/distributions/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ py_test(
8888
deps = [
8989
":deep_factorized",
9090
":round_adapters",
91+
"//tensorflow_compression/python/ops:soft_round_ops",
9192
],
9293
)
9394

tensorflow_compression/python/distributions/helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def quantization_offset(distribution):
104104
these are implemented, it falls back on quantizing to integer values (i.e.,
105105
an offset of zero).
106106
107+
Note the offset is always in the range [-.5, .5] as it is assumed to be
108+
combined with a round quantizer.
109+
107110
Arguments:
108111
distribution: A `tfp.distributions.Distribution` object.
109112
@@ -125,7 +128,7 @@ def quantization_offset(distribution):
125128
offset = distribution.mean()
126129
except NotImplementedError:
127130
offset = tf.constant(0, dtype=distribution.dtype)
128-
return tf.stop_gradient(offset)
131+
return tf.stop_gradient(offset - tf.round(offset))
129132

130133

131134
def lower_tail(distribution, tail_mass):

tensorflow_compression/python/distributions/helpers_test.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,29 @@
2323

2424
class HelpersTest(tf.test.TestCase):
2525

26-
def test_cauchy_quantizes_to_mode(self):
27-
dist = tfp.distributions.Cauchy(loc=1.5, scale=3.)
28-
self.assertEqual(helpers.quantization_offset(dist), 1.5)
26+
def test_cauchy_quantizes_to_mode_decimal_part(self):
27+
dist = tfp.distributions.Cauchy(loc=1.4, scale=3.)
28+
self.assertAllClose(helpers.quantization_offset(dist), 0.4)
2929

30-
def test_gamma_quantizes_to_mode(self):
30+
def test_gamma_quantizes_to_mode_decimal_part(self):
3131
dist = tfp.distributions.Gamma(concentration=5., rate=1.)
32-
self.assertEqual(helpers.quantization_offset(dist), 4.)
32+
self.assertEqual(helpers.quantization_offset(dist), 0.)
3333

34-
def test_laplace_quantizes_to_mode(self):
34+
def test_laplace_quantizes_to_mode_decimal_part(self):
3535
dist = tfp.distributions.Laplace(loc=-2., scale=5.)
36-
self.assertEqual(helpers.quantization_offset(dist), -2.)
36+
self.assertEqual(helpers.quantization_offset(dist), 0.)
3737

38-
def test_logistic_quantizes_to_mode(self):
38+
def test_logistic_quantizes_to_mode_decimal_part(self):
3939
dist = tfp.distributions.Logistic(loc=-3., scale=1.)
40-
self.assertEqual(helpers.quantization_offset(dist), -3.)
40+
self.assertEqual(helpers.quantization_offset(dist), 0.)
4141

42-
def test_lognormal_quantizes_to_mode(self):
42+
def test_lognormal_quantizes_to_mode_decimal_part(self):
4343
dist = tfp.distributions.LogNormal(loc=4., scale=1.)
44-
self.assertEqual(helpers.quantization_offset(dist), tf.exp(3.))
44+
self.assertAllClose(helpers.quantization_offset(dist), tf.exp(3.)-20.0)
4545

46-
def test_normal_quantizes_to_mode(self):
46+
def test_normal_quantizes_to_mode_decimal_part(self):
4747
dist = tfp.distributions.Normal(loc=3., scale=5.)
48-
self.assertEqual(helpers.quantization_offset(dist), 3.)
48+
self.assertEqual(helpers.quantization_offset(dist), 0.)
4949

5050
def test_cauchy_tails_are_in_order(self):
5151
dist = tfp.distributions.Cauchy(loc=1.5, scale=3.)

tensorflow_compression/python/distributions/round_adapters_test.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from tensorflow_compression.python.distributions import deep_factorized
2323
from tensorflow_compression.python.distributions import round_adapters
24+
from tensorflow_compression.python.ops import soft_round_ops
2425

2526

2627
def _test_log_prob_gradient_is_bounded(self, dist_cls, values, params=()):
@@ -49,19 +50,21 @@ class AdaptersTest(tf.test.TestCase, parameterized.TestCase):
4950
deep_factorized.DeepFactorized, 0.0),
5051
("softround_logistic",
5152
lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0),
52-
lambda: tfp.distributions.Logistic(loc=10, scale=1.5), 10.0),
53+
lambda: tfp.distributions.Logistic(loc=10.3, scale=1.5),
54+
lambda: soft_round_ops.soft_round(0.3, alpha=5.0)),
5355
("softround_normal",
54-
lambda d: round_adapters.SoftRoundAdapter(d, alpha=5.0),
55-
lambda: tfp.distributions.Normal(loc=10, scale=1.5), 10.0),
56+
lambda d: round_adapters.SoftRoundAdapter(d, alpha=4.0),
57+
lambda: tfp.distributions.Normal(loc=10.4, scale=1.5),
58+
lambda: soft_round_ops.soft_round(0.4, alpha=4.0)),
5659
("noisysoftround_deepfactorized",
5760
lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
5861
deep_factorized.DeepFactorized, 0.0),
5962
("noisysoftround_logistic",
6063
lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
61-
lambda: tfp.distributions.Logistic(loc=10, scale=1.5), 10.0),
64+
lambda: tfp.distributions.Logistic(loc=10, scale=1.5), 0.0),
6265
("noisysoftround_normal",
6366
lambda d: round_adapters.NoisySoftRoundAdapter(d, alpha=5.0),
64-
lambda: tfp.distributions.Normal(loc=10, scale=1.5), 10.0),
67+
lambda: tfp.distributions.Normal(loc=10, scale=1.5), 0.0),
6568
("round_deepfactorized",
6669
round_adapters.RoundAdapter,
6770
lambda: deep_factorized.DeepFactorized(init_scale=1.0), 0.0),
@@ -101,6 +104,10 @@ def test_tails_and_offset(self, adapter, distribution, expected_offset):
101104

102105
self.assertGreater(upper_tail, lower_tail)
103106
offset = dist._quantization_offset()
107+
if not isinstance(expected_offset, float):
108+
# We cannot run tf inside the parameterized test declaration, hence
109+
# non-float values are wrapped in a lambda.
110+
expected_offset = expected_offset()
104111
self.assertAllClose(offset, expected_offset)
105112

106113
@parameterized.named_parameters(

tensorflow_compression/python/distributions/uniform_noise_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ def test_sampling_works(self):
6060
self.assertEqual(sample.shape, (5, 4, 2))
6161

6262
def test_tails_and_offset_are_in_order(self):
63-
dist = self.dist_cls(loc=10, scale=1.5)
63+
dist = self.dist_cls(loc=10.3, scale=1.5)
6464
offset = helpers.quantization_offset(dist)
6565
lower_tail = helpers.lower_tail(dist, 2**-8)
6666
upper_tail = helpers.upper_tail(dist, 2**-8)
67-
self.assertGreater(upper_tail, offset)
68-
self.assertGreater(offset, lower_tail)
67+
self.assertGreater(upper_tail, lower_tail)
68+
self.assertAllClose(offset, 0.3)
6969

7070
def test_stats_throw_error(self):
7171
dist = self.dist_cls(loc=1, scale=2)
@@ -130,12 +130,12 @@ def test_sampling_works(self):
130130
self.assertEqual(sample.shape, (5, 4, 1))
131131

132132
def test_tails_and_offset_are_in_order(self):
133-
dist = self.dist_cls(loc=10, scale=[1.5, 2], weight=[.5, .5])
133+
dist = self.dist_cls(loc=[5.4, 8.6], scale=[1.4, 2], weight=[.6, .4])
134134
offset = helpers.quantization_offset(dist)
135135
lower_tail = helpers.lower_tail(dist, 2**-8)
136136
upper_tail = helpers.upper_tail(dist, 2**-8)
137-
self.assertGreater(upper_tail, offset)
138-
self.assertGreater(offset, lower_tail)
137+
self.assertGreater(upper_tail, lower_tail)
138+
self.assertAllClose(offset, 0.4) # Decimal part of the peakiest mode (5.4).
139139

140140
def test_stats_throw_error(self):
141141
dist = self.dist_cls(loc=[1, 0], scale=2, weight=[.1, .9])

tensorflow_compression/python/entropy_models/BUILD

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ py_library(
1212
":continuous_base",
1313
":continuous_batched",
1414
":continuous_indexed",
15+
":universal",
1516
],
1617
)
1718

@@ -32,6 +33,7 @@ py_library(
3233
deps = [
3334
":continuous_base",
3435
"//tensorflow_compression/python/distributions:helpers",
36+
"//tensorflow_compression/python/ops:math_ops",
3537
"//tensorflow_compression/python/ops:range_coding_ops",
3638
],
3739
)
@@ -68,6 +70,30 @@ py_test(
6870
],
6971
)
7072

73+
py_library(
74+
name = "universal",
75+
srcs = ["universal.py"],
76+
srcs_version = "PY3",
77+
deps = [
78+
":continuous_batched",
79+
":continuous_indexed",
80+
"//tensorflow_compression/python/ops:math_ops",
81+
],
82+
)
83+
84+
py_test(
85+
name = "universal_test",
86+
timeout = "long",
87+
srcs = ["universal_test.py"],
88+
python_version = "PY3",
89+
shard_count = 3,
90+
deps = [
91+
":universal",
92+
"//tensorflow_compression/python/distributions:deep_factorized",
93+
"//tensorflow_compression/python/distributions:uniform_noise",
94+
],
95+
)
96+
7197
filegroup(
7298
name = "py_src",
7399
srcs = glob(["*.py"]),

0 commit comments

Comments
 (0)