Skip to content

Commit 72366f5

Browse files
Johannes Ballécopybara-github
authored andcommitted
Refactors and updates models to TF2 API/Keras.
Adds support for training/validating on CLIC dataset. Simplifies PackedTensors, and fixes two bugs with serialization of activation functions and layers that have not been built. PiperOrigin-RevId: 361232886 Change-Id: I1ebffe4e89fcfeb899994f0bd9931fb3ac8641b4
1 parent ef48d38 commit 72366f5

File tree

10 files changed

+1226
-979
lines changed

10 files changed

+1226
-979
lines changed

models/bls2017.py

Lines changed: 316 additions & 251 deletions
Large diffs are not rendered by default.

models/bmshj2018.py

Lines changed: 399 additions & 352 deletions
Large diffs are not rendered by default.

models/ms2020.py

Lines changed: 441 additions & 293 deletions
Large diffs are not rendered by default.

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self,
9797
`compress()` and `decompress()` will be built on instantiation. If set
9898
to `False`, these two methods will not be accessible.
9999
laplace_tail_mass: Float. If positive, will augment the prior with a
100-
laplace mixture for training stability. (experimental)
100+
Laplace mixture for training stability. (experimental)
101101
expected_grads: If True, will use analytical expected gradients during
102102
backpropagation w.r.t. additive uniform noise.
103103
tail_mass: Float. Approximate probability mass which is range encoded with
@@ -164,7 +164,6 @@ def _compute_indexes_and_offset(self, broadcast_shape):
164164
def __call__(self, bottleneck, training=True):
165165
"""Perturbs a tensor with (quantization) noise and estimates bitcost.
166166
167-
168167
Args:
169168
bottleneck: `tf.Tensor` containing the data to be compressed. Must have at
170169
least `self.coding_rank` dimensions, and the innermost dimensions must
@@ -280,9 +279,8 @@ def decompress(self, strings, broadcast_shape):
280279
Args:
281280
strings: `tf.Tensor` containing the compressed bit strings.
282281
broadcast_shape: Iterable of ints. The part of the output tensor shape
283-
between the shape of `strings` on the left and
284-
`self.prior_shape` on the right. This must match the shape
285-
of the input to `compress()`.
282+
between the shape of `strings` on the left and `self.prior_shape` on the
283+
right. This must match the shape of the input to `compress()`.
286284
287285
Returns:
288286
A `tf.Tensor` of shape `strings.shape + broadcast_shape +

tensorflow_compression/python/layers/gdn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,14 +423,16 @@ def get_config(self) -> Dict[str, Any]:
423423

424424
# Since alpha and epsilon are scalar, allow fixed values to be serialized.
425425
def try_serialize(parameter, name):
426+
if parameter is None:
427+
return None
426428
try:
427429
return tf.keras.utils.serialize_keras_object(parameter)
428430
except (ValueError, TypeError): # Should throw TypeError, but doesn't...
429431
try:
430432
return float(parameter)
431433
except TypeError:
432434
raise TypeError(
433-
f"Can't serialize {name} of type '{type(parameter)}'.")
435+
f"Can't serialize {name} of type {type(parameter)}.")
434436

435437
alpha_parameter = try_serialize(
436438
self.alpha_parameter, "alpha_parameter")

tensorflow_compression/python/layers/gdn_test.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def test_variables_receive_gradients(self):
146146
weight_shapes = [tuple(w.shape) for w in layer.trainable_weights]
147147
self.assertSameElements(grad_shapes, weight_shapes)
148148

149-
def test_can_be_saved_within_functional_model(self):
149+
@parameterized.parameters(False, True)
150+
def test_can_be_saved_within_functional_model(self, build):
150151
inputs = tf.keras.Input(shape=(5,))
151152
outputs = gdn.GDN()(inputs)
152153
model = tf.keras.Model(inputs=inputs, outputs=outputs)
@@ -161,12 +162,13 @@ def test_can_be_saved_within_functional_model(self):
161162
self.assertIsInstance(layer.epsilon_parameter, tf.Tensor)
162163
self.assertEmpty(layer.epsilon_parameter.shape)
163164

164-
x = tf.random.uniform((5, 5), dtype=tf.float32)
165-
y = model(x)
166-
weight_names = [w.name for w in model.weights]
165+
if build:
166+
x = tf.random.uniform((5, 5), dtype=tf.float32)
167+
y = model(x)
168+
weight_names = [w.name for w in model.weights]
167169

168170
tempdir = self.create_tempdir()
169-
model_path = os.path.join(tempdir.full_path, "model")
171+
model_path = os.path.join(tempdir, "model")
170172
# This should force the model to be reconstructed via configs.
171173
model.save(model_path, save_traces=False)
172174

@@ -182,11 +184,12 @@ def test_can_be_saved_within_functional_model(self):
182184
self.assertIsInstance(layer.epsilon_parameter, tf.Tensor)
183185
self.assertEmpty(layer.epsilon_parameter.shape)
184186

185-
with self.subTest(name="model_outputs_identical"):
186-
self.assertAllEqual(model(x), y)
187+
if build:
188+
with self.subTest(name="model_outputs_identical"):
189+
self.assertAllEqual(model(x), y)
187190

188-
with self.subTest(name="model_weights_identical"):
189-
self.assertSameElements(weight_names, [w.name for w in model.weights])
191+
with self.subTest(name="model_weights_identical"):
192+
self.assertSameElements(weight_names, [w.name for w in model.weights])
190193

191194

192195
if __name__ == "__main__":

tensorflow_compression/python/layers/signal_conv.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ def activation(self) -> Optional[Callable[[Any], tf.Tensor]]:
439439
@activation.setter
440440
def activation(self, value):
441441
self._check_not_built()
442-
self._activation = value
442+
self._activation = tf.keras.activations.get(value)
443443

444444
@property
445445
def use_bias(self) -> bool:
@@ -938,7 +938,7 @@ def call(self, inputs) -> tf.Tensor:
938938

939939
# Finally, pass through activation function if requested.
940940
if self.activation is not None:
941-
outputs = self.activation(outputs) # pylint:disable=not-callable
941+
outputs = self.activation(outputs)
942942

943943
return outputs
944944

@@ -979,13 +979,15 @@ def get_config(self) -> Dict[str, Any]:
979979
# Special-case variables, which can't be serialized but are handled by
980980
# get_weights()/set_weights().
981981
def try_serialize(parameter, name):
982+
if isinstance(parameter, str):
983+
return parameter
982984
try:
983985
return tf.keras.utils.serialize_keras_object(parameter)
984986
except (ValueError, TypeError): # Should throw TypeError, but doesn't...
985987
if isinstance(parameter, tf.Variable):
986988
return "variable"
987989
raise TypeError(
988-
f"Can't serialize {name} of type '{type(parameter)}'.")
990+
f"Can't serialize {name} of type {type(parameter)}.")
989991

990992
kernel_parameter = try_serialize(self.kernel_parameter, "kernel")
991993
bias_parameter = try_serialize(self.bias_parameter, "bias")
@@ -1000,7 +1002,7 @@ def try_serialize(parameter, name):
10001002
extra_pad_end=self.extra_pad_end,
10011003
channel_separable=self.channel_separable,
10021004
data_format=self.data_format,
1003-
activation=self.activation,
1005+
activation=tf.keras.activations.serialize(self.activation),
10041006
use_bias=self.use_bias,
10051007
use_explicit=self.use_explicit,
10061008
kernel_parameter=kernel_parameter,

tensorflow_compression/python/layers/signal_conv_test.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Tests of signal processing convolution layers."""
1616

1717
import os
18+
from absl.testing import parameterized
1819
import numpy as np
1920
import scipy.signal
2021
import tensorflow as tf
@@ -23,7 +24,7 @@
2324
from tensorflow_compression.python.layers import signal_conv
2425

2526

26-
class SignalConvTest(tf.test.TestCase):
27+
class SignalConvTest(tf.test.TestCase, parameterized.TestCase):
2728

2829
def test_invalid_data_format_raises_error(self):
2930
with self.assertRaises(ValueError):
@@ -112,9 +113,11 @@ def test_variables_receive_gradients(self):
112113
weight_shapes = [tuple(w.shape) for w in layer.trainable_weights]
113114
self.assertSameElements(grad_shapes, weight_shapes)
114115

115-
def test_can_be_saved_within_functional_model(self):
116+
@parameterized.parameters(False, True)
117+
def test_can_be_saved_within_functional_model(self, build):
116118
inputs = tf.keras.Input(shape=(None, 2))
117-
outputs = signal_conv.SignalConv1D(1, 3, use_bias=True)(inputs)
119+
outputs = signal_conv.SignalConv1D(
120+
1, 3, use_bias=True, activation=tf.nn.relu)(inputs)
118121
model = tf.keras.Model(inputs=inputs, outputs=outputs)
119122
layer = model.get_layer("signal_conv1d")
120123

@@ -123,12 +126,13 @@ def test_can_be_saved_within_functional_model(self):
123126
self.assertIsInstance(layer.kernel_parameter, parameters.RDFTParameter)
124127
self.assertIsInstance(layer.bias_parameter, tf.Variable)
125128

126-
x = tf.random.uniform((1, 5, 2), dtype=tf.float32)
127-
y = model(x)
128-
weight_names = [w.name for w in model.weights]
129+
if build:
130+
x = tf.random.uniform((1, 5, 2), dtype=tf.float32)
131+
y = model(x)
132+
weight_names = [w.name for w in model.weights]
129133

130134
tempdir = self.create_tempdir()
131-
model_path = os.path.join(tempdir.full_path, "model")
135+
model_path = os.path.join(tempdir, "model")
132136
# This should force the model to be reconstructed via configs.
133137
model.save(model_path, save_traces=False)
134138

@@ -140,11 +144,12 @@ def test_can_be_saved_within_functional_model(self):
140144
self.assertIsInstance(layer.kernel_parameter, parameters.RDFTParameter)
141145
self.assertIsInstance(layer.bias_parameter, tf.Variable)
142146

143-
with self.subTest(name="model_outputs_identical"):
144-
self.assertAllEqual(model(x), y)
147+
if build:
148+
with self.subTest(name="model_outputs_identical"):
149+
self.assertAllEqual(model(x), y)
145150

146-
with self.subTest(name="model_weights_identical"):
147-
self.assertSameElements(weight_names, [w.name for w in model.weights])
151+
with self.subTest(name="model_weights_identical"):
152+
self.assertSameElements(weight_names, [w.name for w in model.weights])
148153

149154

150155
class ConvolutionsTest(tf.test.TestCase):
@@ -353,7 +358,7 @@ def run_or_fail(self, method,
353358
except:
354359
msg = []
355360
for k in sorted(args):
356-
msg.append("{}={}".format(k, args[k]))
361+
msg.append(f"{k}={args[k]}")
357362
print("Failed when it shouldn't have: " + ", ".join(msg))
358363
raise
359364
else:
@@ -363,7 +368,7 @@ def run_or_fail(self, method,
363368
except:
364369
msg = []
365370
for k in sorted(args):
366-
msg.append("{}={}".format(k, args[k]))
371+
msg.append(f"{k}={args[k]}")
367372
print("Did not fail when it should have: " + ", ".join(msg))
368373
raise
369374

tensorflow_compression/python/util/packed_tensors.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# ==============================================================================
1515
"""Packed tensors in bit sequences."""
1616

17-
import numpy as np
1817
import tensorflow as tf
1918

2019

@@ -62,50 +61,36 @@ def string(self):
6261
def string(self, value):
6362
self._example.ParseFromString(value)
6463

65-
def pack(self, tensors, arrays):
64+
def pack(self, tensors):
6665
"""Packs `Tensor` values into this object."""
67-
if len(tensors) != len(arrays):
68-
raise ValueError("`tensors` and `arrays` must have same length.")
6966
i = 1
70-
for tensor, array in zip(tensors, arrays):
67+
for tensor in tensors:
7168
feature = self._example.features.feature[chr(i)]
7269
feature.Clear()
73-
if array.ndim != 1:
74-
raise RuntimeError("Unexpected tensor rank: {}.".format(array.ndim))
70+
if tensor.shape.rank != 1:
71+
raise RuntimeError(f"Unexpected tensor rank: {tensor.shape.rank}.")
7572
if tensor.dtype.is_integer:
76-
feature.int64_list.value[:] = array
73+
feature.int64_list.value[:] = tensor.numpy()
7774
elif tensor.dtype == tf.string:
78-
feature.bytes_list.value[:] = array
75+
feature.bytes_list.value[:] = tensor.numpy()
7976
else:
80-
raise RuntimeError(
81-
"Unexpected tensor dtype: '{}'.".format(tensor.dtype))
77+
raise RuntimeError(f"Unexpected tensor dtype: '{tensor.dtype}'.")
8278
i += 1
8379
# Delete any remaining, previously set arrays.
8480
while chr(i) in self._example.features.feature:
8581
del self._example.features.feature[chr(i)]
8682
i += 1
8783

88-
# TODO(jonycgn): Remove this function once all models are converted.
89-
def unpack(self, tensors):
90-
"""Unpacks `Tensor` values from this object."""
91-
# Check tensor dtype first for a more informative error message.
92-
for x in tensors:
93-
if not x.dtype.is_integer and x.dtype != tf.string:
94-
raise RuntimeError("Unexpected tensor dtype: '{}'.".format(x.dtype))
95-
96-
# Extact numpy dtypes and call type-based API.
97-
np_dtypes = [x.dtype.as_numpy_dtype for x in tensors]
98-
return self.unpack_from_np_dtypes(np_dtypes)
99-
100-
def unpack_from_np_dtypes(self, np_dtypes):
101-
"""Unpacks values from this object based on numpy dtypes."""
102-
arrays = []
103-
for i, np_dtype in enumerate(np_dtypes):
84+
def unpack(self, dtypes):
85+
"""Unpacks values from this object based on dtypes."""
86+
tensors = []
87+
for i, dtype in enumerate(dtypes):
88+
dtype = tf.as_dtype(dtype)
10489
feature = self._example.features.feature[chr(i + 1)]
105-
if np.issubdtype(np_dtype, np.integer):
106-
arrays.append(np.array(feature.int64_list.value, dtype=np_dtype))
107-
elif np_dtype == np.dtype(object) or np.issubdtype(np_dtype, np.bytes_):
108-
arrays.append(np.array(feature.bytes_list.value, dtype=np_dtype))
90+
if dtype.is_integer:
91+
tensors.append(tf.constant(feature.int64_list.value, dtype=dtype))
92+
elif dtype == tf.string:
93+
tensors.append(tf.constant(feature.bytes_list.value, dtype=dtype))
10994
else:
110-
raise RuntimeError("Unexpected numpy dtype: '{}'.".format(np_dtype))
111-
return arrays
95+
raise RuntimeError(f"Unexpected dtype: '{dtype}'.")
96+
return tensors

tensorflow_compression/python/util/packed_tensors_test.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,25 @@
1414
# ==============================================================================
1515
"""Tests of PackedTensors class."""
1616

17-
import numpy as np
1817
import tensorflow as tf
1918
from tensorflow_compression.python.util import packed_tensors
2019

2120

2221
class PackedTensorsTest(tf.test.TestCase):
2322

24-
def test_pack_unpack(self):
25-
"""Tests packing and unpacking tensors."""
26-
string = np.array(["xyz".encode("ascii")], dtype=object)
27-
shape = np.array([1, 3], dtype=np.int32)
28-
arrays = [string, shape]
29-
30-
string_t = tf.zeros([1], dtype=tf.string)
31-
shape_t = tf.zeros([2], dtype=tf.int32)
32-
tensors = [string_t, shape_t]
33-
23+
def test_pack_unpack_identity(self):
24+
"""Tests packing and unpacking tensors returns the same values."""
25+
string = tf.constant(["xyz"], dtype=tf.string)
26+
shape = tf.constant([1, 3], dtype=tf.int32)
3427
packed = packed_tensors.PackedTensors()
35-
packed.pack(tensors, arrays)
28+
packed.pack([string, shape])
3629
packed = packed_tensors.PackedTensors(packed.string)
37-
string_u, shape_u = packed.unpack(tensors)
38-
39-
self.assertAllEqual(string_u, string)
40-
self.assertAllEqual(shape_u, shape)
30+
string_unpacked, shape_unpacked = packed.unpack([tf.string, tf.int32])
31+
self.assertAllEqual(string_unpacked, string)
32+
self.assertAllEqual(shape_unpacked, shape)
4133

42-
def test_model(self):
43-
"""Tests setting and getting model."""
34+
def test_set_get_model_identity(self):
35+
"""Tests setting and getting model returns the same value."""
4436
packed = packed_tensors.PackedTensors()
4537
packed.model = "xyz"
4638
packed = packed_tensors.PackedTensors(packed.string)

0 commit comments

Comments
 (0)