Skip to content

Commit 18e4127

Browse files
author
Johannes Ballé
committed
Add SignalConv and GDN layer classes, plus dependencies.
1 parent 6fbfebb commit 18e4127

19 files changed

+2574
-40
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ For usage questions and discussions, please head over to our
1010
**Please note**: You need TensorFlow 1.9 (or the master branch as of May 2018)
1111
or later.
1212

13-
To make sure the library imports succeed, try running the two
14-
tests.
13+
To make sure the library imports succeed, try running the unit tests.
1514
```
16-
python tensorflow_compression/python/ops/coder_ops_test.py
17-
python tensorflow_compression/python/layers/entropybottleneck_test.py
15+
for i in tensorflow_compression/python/*/*_test.py; do
16+
python $i
17+
done
1818
```
1919

2020
## Entropy bottleneck layer

tensorflow_compression/__init__.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,26 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22-
# pylint: disable=wildcard-import
23-
from tensorflow_compression.python.layers.entropybottleneck import *
24-
from tensorflow.contrib.coder.python.ops.coder_ops import *
25-
# pylint: enable=wildcard-import
22+
# Dependency imports
2623

2724
from tensorflow.python.util.all_util import remove_undocumented
25+
26+
# pylint: disable=wildcard-import,g-bad-import-order
27+
from tensorflow.contrib.coder.python.ops.coder_ops import *
28+
from tensorflow_compression.python.layers.entropy_models import *
29+
from tensorflow_compression.python.layers.gdn import *
30+
from tensorflow_compression.python.layers.initializers import *
31+
from tensorflow_compression.python.layers.parameterizers import *
32+
from tensorflow_compression.python.layers.signal import *
33+
from tensorflow_compression.python.ops.math_ops import *
34+
from tensorflow_compression.python.ops.padding_ops import *
35+
from tensorflow_compression.python.ops.spectral_ops import *
36+
# pylint: enable=wildcard-import,g-bad-import-order
37+
2838
remove_undocumented(__name__, [
29-
"EntropyBottleneck",
39+
"EntropyBottleneck", "GDN", "IdentityInitializer", "Parameterizer",
40+
"StaticParameterizer", "RDFTParameterizer", "NonnegativeParameterizer",
41+
"SignalConv1D", "SignalConv2D", "SignalConv3D",
42+
"upper_bound", "lower_bound", "same_padding_for_kernel", "irdft_matrix",
3043
"pmf_to_quantized_cdf", "range_decode", "range_encode",
3144
])

tensorflow_compression/python/layers/entropybottleneck.py renamed to tensorflow_compression/python/layers/entropy_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22-
from tensorflow.contrib.coder.python.ops import coder_ops
22+
# Dependency imports
2323

2424
import numpy as np
2525

26+
from tensorflow.contrib.coder.python.ops import coder_ops
27+
2628
from tensorflow.python.eager import context
2729
from tensorflow.python.framework import constant_op
2830
from tensorflow.python.framework import dtypes

tensorflow_compression/python/layers/entropybottleneck_test.py renamed to tensorflow_compression/python/layers/entropy_models_test.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22-
import numpy as np
22+
# Dependency imports
2323

24-
from tensorflow_compression.python.layers import entropybottleneck
24+
import numpy as np
2525

2626
from tensorflow.python.framework import dtypes
2727
from tensorflow.python.ops import array_ops
@@ -30,13 +30,15 @@
3030
from tensorflow.python.platform import test
3131
from tensorflow.python.training import gradient_descent
3232

33+
from tensorflow_compression.python.layers import entropy_models
34+
3335

3436
class EntropyBottleneckTest(test.TestCase):
3537

3638
def test_noise(self):
3739
# Tests that the noise added is uniform noise between -0.5 and 0.5.
3840
inputs = array_ops.placeholder(dtypes.float32, (None, 1))
39-
layer = entropybottleneck.EntropyBottleneck()
41+
layer = entropy_models.EntropyBottleneck()
4042
noisy, _ = layer(inputs, training=True)
4143
with self.test_session() as sess:
4244
sess.run(variables.global_variables_initializer())
@@ -49,7 +51,7 @@ def test_quantization(self):
4951
# Tests that inputs are quantized to full integer values, even after
5052
# quantiles have been updated.
5153
inputs = array_ops.placeholder(dtypes.float32, (None, 1))
52-
layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False)
54+
layer = entropy_models.EntropyBottleneck(optimize_integer_offset=False)
5355
quantized, _ = layer(inputs, training=False)
5456
opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
5557
self.assertTrue(len(layer.losses) == 1)
@@ -66,7 +68,7 @@ def test_quantization_optimized_offset(self):
6668
# have been updated. However, the difference between input and output should
6769
# be between -0.5 and 0.5, and the offset must be consistent.
6870
inputs = array_ops.placeholder(dtypes.float32, (None, 1))
69-
layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True)
71+
layer = entropy_models.EntropyBottleneck(optimize_integer_offset=True)
7072
quantized, _ = layer(inputs, training=False)
7173
opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
7274
self.assertTrue(len(layer.losses) == 1)
@@ -85,7 +87,7 @@ def test_codec(self):
8587
# Tests that inputs are compressed and decompressed correctly, and quantized
8688
# to full integer values, even after quantiles have been updated.
8789
inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
88-
layer = entropybottleneck.EntropyBottleneck(
90+
layer = entropy_models.EntropyBottleneck(
8991
data_format="channels_last", init_scale=60,
9092
optimize_integer_offset=False)
9193
bitstrings = layer.compress(inputs)
@@ -108,7 +110,7 @@ def test_codec_optimized_offset(self):
108110
# However, the difference between input and output should be between -0.5
109111
# and 0.5, and the offset must be consistent.
110112
inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
111-
layer = entropybottleneck.EntropyBottleneck(
113+
layer = entropy_models.EntropyBottleneck(
112114
data_format="channels_last", init_scale=60,
113115
optimize_integer_offset=True)
114116
bitstrings = layer.compress(inputs)
@@ -132,7 +134,7 @@ def test_codec_clipping(self):
132134
# Tests that inputs are compressed and decompressed correctly, and clipped
133135
# to the expected range.
134136
inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
135-
layer = entropybottleneck.EntropyBottleneck(
137+
layer = entropy_models.EntropyBottleneck(
136138
data_format="channels_last", init_scale=40)
137139
bitstrings = layer.compress(inputs)
138140
decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
@@ -149,7 +151,7 @@ def test_channels_last(self):
149151
# Test the layer with more than one channel and multiple input dimensions,
150152
# with the channels in the last dimension.
151153
inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2))
152-
layer = entropybottleneck.EntropyBottleneck(
154+
layer = entropy_models.EntropyBottleneck(
153155
data_format="channels_last", init_scale=50)
154156
noisy, _ = layer(inputs, training=True)
155157
quantized, _ = layer(inputs, training=False)
@@ -170,7 +172,7 @@ def test_channels_first(self):
170172
# Test the layer with more than one channel and multiple input dimensions,
171173
# with the channel dimension right after the batch dimension.
172174
inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None))
173-
layer = entropybottleneck.EntropyBottleneck(
175+
layer = entropy_models.EntropyBottleneck(
174176
data_format="channels_first", init_scale=50)
175177
noisy, _ = layer(inputs, training=True)
176178
quantized, _ = layer(inputs, training=False)
@@ -192,7 +194,7 @@ def test_compress(self):
192194
# `test_decompress`. If you set the constant at the end to `True`, this test
193195
# will fail and the log will contain the new test data.
194196
inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10))
195-
layer = entropybottleneck.EntropyBottleneck(
197+
layer = entropy_models.EntropyBottleneck(
196198
data_format="channels_first", filters=(), init_scale=2)
197199
bitstrings = layer.compress(inputs)
198200
decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
@@ -237,7 +239,7 @@ def test_decompress(self):
237239
bitstrings = array_ops.placeholder(dtypes.string)
238240
input_shape = array_ops.placeholder(dtypes.int32)
239241
quantized_cdf = array_ops.placeholder(dtypes.int32)
240-
layer = entropybottleneck.EntropyBottleneck(
242+
layer = entropy_models.EntropyBottleneck(
241243
data_format="channels_first", filters=(), dtype=dtypes.float32)
242244
layer.build(self.expected.shape)
243245
layer._quantized_cdf = quantized_cdf
@@ -253,13 +255,13 @@ def test_build_decompress(self):
253255
# Test that layer can be built when `decompress` is the first call to it.
254256
bitstrings = array_ops.placeholder(dtypes.string)
255257
input_shape = array_ops.placeholder(dtypes.int32, shape=[3])
256-
layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
258+
layer = entropy_models.EntropyBottleneck(dtype=dtypes.float32)
257259
layer.decompress(bitstrings, input_shape[1:], channels=5)
258260
self.assertTrue(layer.built)
259261

260262
def test_pmf_normalization(self):
261263
# Test that probability mass functions are normalized correctly.
262-
layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
264+
layer = entropy_models.EntropyBottleneck(dtype=dtypes.float32)
263265
layer.build((None, 10))
264266
with self.test_session() as sess:
265267
sess.run(variables.global_variables_initializer())
@@ -268,7 +270,7 @@ def test_pmf_normalization(self):
268270

269271
def test_visualize(self):
270272
# Test that summary op can be constructed.
271-
layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
273+
layer = entropy_models.EntropyBottleneck(dtype=dtypes.float32)
272274
layer.build((None, 10))
273275
summary = layer.visualize()
274276
with self.test_session() as sess:
@@ -278,7 +280,7 @@ def test_visualize(self):
278280
def test_normalization(self):
279281
# Test that densities are normalized correctly.
280282
inputs = array_ops.placeholder(dtypes.float32, (None, 1))
281-
layer = entropybottleneck.EntropyBottleneck(filters=(2,))
283+
layer = entropy_models.EntropyBottleneck(filters=(2,))
282284
_, likelihood = layer(inputs, training=True)
283285
with self.test_session() as sess:
284286
sess.run(variables.global_variables_initializer())
@@ -291,7 +293,7 @@ def test_normalization(self):
291293
def test_entropy_estimates(self):
292294
# Test that entropy estimates match actual range coding.
293295
inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
294-
layer = entropybottleneck.EntropyBottleneck(
296+
layer = entropy_models.EntropyBottleneck(
295297
filters=(2, 3), data_format="channels_last")
296298
_, likelihood = layer(inputs, training=True)
297299
diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2)

0 commit comments

Comments
 (0)