Skip to content

Commit eeb9a0f

Browse files
relational authored and copybara-github committed
Integrate soft rounding ops and layers into tensorflow_compression.
PiperOrigin-RevId: 338018642 Change-Id: Icffe8ebe4cc03c0de580d584568abe900805e88c
1 parent fa8ac89 commit eeb9a0f

File tree

11 files changed

+901
-2
lines changed

11 files changed

+901
-2
lines changed

tensorflow_compression/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
raise RuntimeError(
2323
"For tensorflow_compression, please install TensorFlow 2.1.")
2424

25-
2625
# pylint: disable=wildcard-import
2726
from tensorflow_compression.python.distributions.deep_factorized import *
2827
from tensorflow_compression.python.distributions.helpers import *
28+
from tensorflow_compression.python.distributions.round_adapters import *
2929
from tensorflow_compression.python.distributions.uniform_noise import *
3030
from tensorflow_compression.python.entropy_models.continuous_batched import *
3131
from tensorflow_compression.python.entropy_models.continuous_indexed import *
@@ -34,9 +34,12 @@
3434
from tensorflow_compression.python.layers.initializers import *
3535
from tensorflow_compression.python.layers.parameterizers import *
3636
from tensorflow_compression.python.layers.signal_conv import *
37+
from tensorflow_compression.python.layers.soft_round import *
3738
from tensorflow_compression.python.ops.math_ops import *
3839
from tensorflow_compression.python.ops.padding_ops import *
3940
from tensorflow_compression.python.ops.range_coding_ops import *
41+
from tensorflow_compression.python.ops.soft_round_ops import *
4042
from tensorflow_compression.python.ops.spectral_ops import *
4143
from tensorflow_compression.python.util.packed_tensors import *
44+
4245
# pylint: enable=wildcard-import

tensorflow_compression/python/all_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
from tensorflow_compression.python.layers.gdn_test import *
2626
from tensorflow_compression.python.layers.parameterizers_test import *
2727
from tensorflow_compression.python.layers.signal_conv_test import *
28+
from tensorflow_compression.python.layers.soft_round_test import *
2829

2930
from tensorflow_compression.python.ops.math_ops_test import *
3031
from tensorflow_compression.python.ops.padding_ops_test import *
3132
from tensorflow_compression.python.ops.range_coding_ops_test import *
33+
from tensorflow_compression.python.ops.round_ops_test import *
3234
from tensorflow_compression.python.ops.spectral_ops_test import *
3335

3436
from tensorflow_compression.python.util.packed_tensors_test import *

tensorflow_compression/python/distributions/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ py_library(
1111
deps = [
1212
":deep_factorized",
1313
":helpers",
14+
":round_adapters",
1415
":uniform_noise",
1516
],
1617
)
@@ -65,6 +66,28 @@ py_test(
6566
],
6667
)
6768

69+
# Library target for the (soft) round distribution adapters.
py_library(
    name = "round_adapters",
    srcs = ["round_adapters.py"],
    srcs_version = "PY3",
    deps = [
        ":deep_factorized",
        ":helpers",
        ":uniform_noise",
        "//tensorflow_compression/python/ops:soft_round_ops",
    ],
)

# Unit tests for the adapters above.
py_test(
    name = "round_adapters_test",
    srcs = ["round_adapters_test.py"],
    python_version = "PY3",
    deps = [
        ":deep_factorized",
        ":round_adapters",
    ],
)
90+
6891
filegroup(
6992
name = "py_src",
7093
srcs = glob(["*.py"]),
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
# Copyright 2020 Google LLC. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Distribution adapters for (soft) round functions."""
16+
import tensorflow as tf
17+
import tensorflow_probability as tfp
18+
19+
from tensorflow_compression.python.distributions import deep_factorized
20+
from tensorflow_compression.python.distributions import helpers
21+
from tensorflow_compression.python.distributions import uniform_noise
22+
from tensorflow_compression.python.ops import soft_round_ops
23+
24+
25+
# Public API of this module. Note: `NoisyRoundAdapter` and
# `NoisySoftRoundAdapter` are public classes defined below; without listing
# them here they would be dropped by the package's `import *` re-export.
__all__ = [
    "MonotonicAdapter",
    "RoundAdapter",
    "NoisyRoundAdapter",
    "NoisyRoundedNormal",
    "NoisyRoundedDeepFactorized",
    "SoftRoundAdapter",
    "NoisySoftRoundAdapter",
    "NoisySoftRoundedNormal",
    "NoisySoftRoundedDeepFactorized",
]
30+
31+
32+
class MonotonicAdapter(tfp.distributions.Distribution):
  """Adapts a continuous base distribution via an ascending monotonic map.

  This is described in Appendix E. in the paper
  > "Universally Quantized Neural Compression"<br />
  > Eirikur Agustsson & Lucas Theis<br />
  > https://arxiv.org/abs/2006.09952
  """

  # Subclasses whose transform has no true inverse should override with False.
  invertible = True

  def __init__(self, base, name="MonotonicAdapter"):
    """Initializer.

    Arguments:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      name: String. A name for this distribution.
    """
    parameters = dict(locals())
    self._base = base
    # Mirror all distribution properties of the base distribution.
    super().__init__(
        dtype=base.dtype,
        reparameterization_type=base.reparameterization_type,
        validate_args=base.validate_args,
        allow_nan_stats=base.allow_nan_stats,
        parameters=parameters,
        name=name,
    )

  @property
  def base(self):
    """The distribution being adapted."""
    return self._base

  def transform(self, x):
    """The forward (ascending monotonic) transform; subclasses implement."""
    raise NotImplementedError()

  def inverse_transform(self, y):
    """Generalized inverse of `transform`.

    With f(x) = self.transform(x), this computes
      g(y) := inf_x { x : f(x) >= y },
    which coincides with the ordinary inverse of `f` when it is invertible.
    """
    raise NotImplementedError()

  def _batch_shape_tensor(self):
    return self.base.batch_shape_tensor()

  def _batch_shape(self):
    return self.base.batch_shape

  def _event_shape_tensor(self):
    return self.base.event_shape_tensor()

  def _event_shape(self):
    return self.base.event_shape

  def _sample_n(self, n, seed=None):
    with tf.name_scope("round"):
      n = tf.convert_to_tensor(n, name="n")
      # Sampling commutes with the transform: draw from the base, then map.
      return self.transform(self.base.sample(n, seed=seed))

  def _prob(self, *args, **kwargs):
    # A density is not defined for the transformed variable in general.
    raise NotImplementedError

  def _log_prob(self, *args, **kwargs):
    raise NotImplementedError

  # The CDF-type statistics below delegate to the base distribution via the
  # generalized inverse, using P( f(x) <= y ) = P( x <= g(y) ).
  # pylint: disable=protected-access
  def _cdf(self, y):
    return self.base._cdf(self.inverse_transform(y))

  def _log_cdf(self, y):
    return self.base._log_cdf(self.inverse_transform(y))

  def _survival_function(self, y):
    return self.base._survival_function(self.inverse_transform(y))

  def _log_survival_function(self, y):
    return self.base._log_survival_function(self.inverse_transform(y))

  def _quantile(self, value):
    # P( x <= z ) = value iff P( f(x) <= f(z) ) = value — but only if the
    # transform is truly invertible.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(self.base._quantile(value))

  def _mode(self):
    # Same invertibility requirement as _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(self.base._mode())

  def _quantization_offset(self):
    # Same invertibility requirement as _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(helpers.quantization_offset(self.base))

  def _lower_tail(self, tail_mass):
    # Same invertibility requirement as _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(helpers.lower_tail(self.base, tail_mass))

  def _upper_tail(self, tail_mass):
    # Same invertibility requirement as _quantile.
    if not self.invertible:
      raise NotImplementedError()
    return self.transform(helpers.upper_tail(self.base, tail_mass))
  # pylint: enable=protected-access
155+
156+
157+
class RoundAdapter(MonotonicAdapter):
  """Continuous density function + round."""

  # Rounding collapses whole intervals onto integers, so no inverse exists.
  invertible = False

  def transform(self, x):
    return tf.round(x)

  def inverse_transform(self, y):
    # Generalized inverse of rounding:
    #   g(y) := inf_x { x : round(x) >= y }
    #         = inf_x { x : round(x) >= ceil(y) }
    #         = ceil(y) - 0.5
    # Sanity check via the CDF:
    #   P( round(x) <= y ) = P( round(x) <= floor(y) )
    #                      = P( x <= floor(y) + 0.5 )
    #                      = P( x <= ceil(y) - 0.5 )
    return tf.math.ceil(y) - 0.5

  def _quantization_offset(self):
    # Rounded values sit exactly on the integer grid.
    return tf.convert_to_tensor(0.0, dtype=self.dtype)

  def _lower_tail(self, tail_mass):
    return tf.math.floor(helpers.lower_tail(self.base, tail_mass))

  def _upper_tail(self, tail_mass):
    return tf.math.ceil(helpers.upper_tail(self.base, tail_mass))
191+
192+
193+
class NoisyRoundAdapter(uniform_noise.UniformNoiseAdapter):
  """Uniform noise + round."""

  def __init__(self, base, name="NoisyRoundAdapter"):
    """Initializer.

    Arguments:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      name: String. A name for this distribution.
    """
    # Round the base distribution first, then convolve with uniform noise.
    super().__init__(RoundAdapter(base), name=name)
205+
206+
207+
class NoisyRoundedDeepFactorized(NoisyRoundAdapter):
  """Rounded DeepFactorized + uniform noise."""

  def __init__(self, name="NoisyRoundedDeepFactorized", **kwargs):
    # Remaining keyword arguments parameterize the DeepFactorized prior.
    super().__init__(base=deep_factorized.DeepFactorized(**kwargs), name=name)
213+
214+
215+
class NoisyRoundedNormal(NoisyRoundAdapter):
  """Rounded normal distribution + uniform noise."""

  def __init__(self, name="NoisyRoundedNormal", **kwargs):
    # Remaining keyword arguments parameterize the normal prior.
    prior = tfp.distributions.Normal(**kwargs)
    super().__init__(base=prior, name=name)
220+
221+
222+
class SoftRoundAdapter(MonotonicAdapter):
  """Differentiable approximation to round."""

  def __init__(self, base, alpha, name="SoftRoundAdapter"):
    """Initializer.

    Arguments:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      alpha: Float or tf.Tensor. Controls smoothness of the approximation.
      name: String. A name for this distribution.
    """
    super().__init__(base=base, name=name)
    # Kept for use by both transform directions below.
    self._alpha = alpha

  def transform(self, x):
    # Smooth, invertible surrogate for rounding.
    return soft_round_ops.soft_round(x, self._alpha)

  def inverse_transform(self, y):
    return soft_round_ops.soft_round_inverse(y, self._alpha)
242+
243+
244+
class NoisySoftRoundAdapter(uniform_noise.UniformNoiseAdapter):
  """Uniform noise + differentiable approximation to round."""

  def __init__(self, base, alpha, name="NoisySoftRoundAdapter"):
    """Initializer.

    Arguments:
      base: A `tfp.distributions.Distribution` object representing a
        continuous-valued random variable.
      alpha: Float or tf.Tensor. Controls smoothness of soft round.
      name: String. A name for this distribution.
    """
    # Soft-round the base distribution, then convolve with uniform noise.
    super().__init__(SoftRoundAdapter(base, alpha), name=name)
257+
258+
259+
class NoisySoftRoundedNormal(NoisySoftRoundAdapter):
  """Soft rounded normal distribution + uniform noise."""

  def __init__(self, alpha=5.0, name="NoisySoftRoundedNormal", **kwargs):
    # Remaining keyword arguments parameterize the normal prior.
    prior = tfp.distributions.Normal(**kwargs)
    super().__init__(base=prior, alpha=alpha, name=name)
267+
268+
269+
class NoisySoftRoundedDeepFactorized(NoisySoftRoundAdapter):
  """Soft rounded deep factorized distribution + uniform noise."""

  def __init__(self, alpha=5.0, name="NoisySoftRoundedDeepFactorized",
               **kwargs):
    # Remaining keyword arguments parameterize the DeepFactorized prior.
    prior = deep_factorized.DeepFactorized(**kwargs)
    super().__init__(base=prior, alpha=alpha, name=name)

0 commit comments

Comments
 (0)