
Commit 604f679

Johannes Ballé authored and tensorflower-gardener committed
Adds EPR algorithm to TF-MOT.
PiperOrigin-RevId: 434513427
1 parent c8cce59 commit 604f679

File tree: 4 files changed (+440 −0 lines)


ci/kokoro/run_bazel_unittests.sh

Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,8 @@ set -o pipefail  # Treat the failure of a command in a pipeline as an error.
 # set -x

 pip install --requirement "requirements.txt"
+# Not in list of requirements, but needed for EPR test:
+pip install tensorflow-compression

 # Run the tests.
 # Some tests requiring more RAM than the CI machine provides are disabled.

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD

Lines changed: 24 additions & 0 deletions

@@ -5,6 +5,30 @@ package(default_visibility = ["//visibility:private"])

 licenses(["notice"])

+pytype_strict_library(
+    name = "epr",
+    srcs = ["epr.py"],
+    srcs_version = "PY3",
+    deps = [
+        # tensorflow dep1,
+        # tensorflow_compression dep1,
+        "//tensorflow_model_optimization/python/core/common/keras/compression:algorithm",
+    ],
+)
+
+py_strict_test(
+    name = "epr_test",
+    timeout = "long",
+    srcs = ["epr_test.py"],
+    python_version = "PY3",
+    tags = ["requires-net:external"],
+    deps = [
+        ":epr",
+        # absl/testing:parameterized dep1,
+        # tensorflow dep1,
+    ],
+)
+
 pytype_strict_library(
     name = "same_training_and_inference",
     srcs = ["same_training_and_inference.py"],
tensorflow_model_optimization/python/core/common/keras/compression/algorithms/epr.py

Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Entropy Penalized Reparameterization algorithm.

This is an implementation of the method described in:
> "Scalable Model Compression by Entropy Penalized Reparameterization"<br />
> D. Oktay, J. Ballé, S. Singh, A. Shrivastava<br />
> https://arxiv.org/abs/1906.06624
"""

import functools
from typing import List
import tensorflow as tf
import tensorflow_compression as tfc
from tensorflow_model_optimization.python.core.common.keras.compression import algorithm


class EPR(algorithm.WeightCompressor):
  """Defines how to apply the EPR algorithm."""

  def __init__(self, entropy_penalty):
    self.entropy_penalty = entropy_penalty

  def get_compressible_weights(self, original_layer):
    if isinstance(
        original_layer,
        (tf.keras.layers.Dense, tf.keras.layers.Conv1D, tf.keras.layers.Conv2D),
    ):
      if original_layer.use_bias:
        return [original_layer.kernel, original_layer.bias]
      else:
        return [original_layer.kernel]
    return []

  def init_training_weights(self, pretrained_weight: tf.Tensor):
    shape = pretrained_weight.shape
    dtype = pretrained_weight.dtype
    weight_name = "bias" if shape.rank == 1 else "kernel"

    if 1 <= shape.rank <= 2:
      # Bias or dense kernel.
      prior_shape = []
      self.add_training_weight(
          name=weight_name,
          shape=pretrained_weight.shape,
          dtype=pretrained_weight.dtype,
          initializer=tf.keras.initializers.Constant(pretrained_weight))
    elif 3 <= shape.rank <= 4:
      # Convolution kernel.
      kernel_shape = tf.shape(pretrained_weight)
      if shape.rank == 3:
        kernel_rdft = tf.signal.rfft(
            tf.transpose(pretrained_weight, (1, 2, 0)))
      else:
        kernel_rdft = tf.signal.rfft2d(
            tf.transpose(pretrained_weight, (2, 3, 0, 1)))
      kernel_rdft = tf.stack(
          [tf.math.real(kernel_rdft), tf.math.imag(kernel_rdft)], axis=-1)
      prior_shape = tf.shape(kernel_rdft)[2:]
      kernel_rdft /= tf.sqrt(tf.cast(tf.reduce_prod(kernel_shape[:-2]), dtype))
      self.add_training_weight(
          name="kernel_rdft",
          shape=kernel_rdft.shape,
          dtype=kernel_rdft.dtype,
          initializer=tf.keras.initializers.Constant(kernel_rdft))
      self.add_training_weight(
          name="kernel_shape",
          shape=kernel_shape.shape,
          dtype=kernel_shape.dtype,
          # TODO(jballe): If False, breaks optimize.create_layer_for_training().
          # If True, throws warnings that int tensors have no gradient.
          # trainable=False,
          initializer=tf.keras.initializers.Constant(kernel_shape))
    else:
      raise ValueError(
          f"Expected bias or kernel tensor with rank between 1 and 4, received "
          f"shape {shape}.")

    # Logarithm of quantization step size.
    log_step = tf.fill(prior_shape, tf.constant(-4, dtype=dtype))
    self.add_training_weight(
        name=f"{weight_name}_log_step",
        shape=log_step.shape,
        dtype=log_step.dtype,
        initializer=tf.keras.initializers.Constant(log_step))

    # Logarithm of scale of prior.
    log_scale = tf.fill(prior_shape, tf.constant(2.5, dtype=dtype))
    self.add_training_weight(
        name=f"{weight_name}_log_scale",
        shape=log_scale.shape,
        dtype=log_scale.dtype,
        initializer=tf.keras.initializers.Constant(log_scale))

  def project_training_weights(self, *training_weights) -> tf.Tensor:
    if len(training_weights) == 3:
      # Bias or dense kernel.
      weight, log_step, _ = training_weights
      step = tf.exp(log_step)
      return tfc.round_st(weight / step) * step
    else:
      # Convolution kernel.
      kernel_rdft, kernel_shape, log_step, _ = training_weights
      step = tf.exp(log_step)
      kernel_rdft = tfc.round_st(kernel_rdft / step)
      kernel_rdft *= step * tf.sqrt(
          tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel_rdft.dtype))
      kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1))
      if kernel_rdft.shape.rank == 3:
        kernel = tf.signal.irfft(kernel_rdft, fft_length=kernel_shape[:-2])
        return tf.transpose(kernel, (2, 0, 1))
      else:
        kernel = tf.signal.irfft2d(kernel_rdft, fft_length=kernel_shape[:-2])
        return tf.transpose(kernel, (2, 3, 0, 1))

  def compress_training_weights(
      self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
    if len(training_weights) == 3:
      # Bias or dense kernel.
      weight, log_step, log_scale = training_weights
      weight_shape = tf.shape(weight)
    else:
      # Convolution kernel.
      weight, weight_shape, log_step, log_scale = training_weights
    prior = tfc.NoisyLogistic(loc=0., scale=tf.exp(log_scale))
    em = tfc.ContinuousBatchedEntropyModel(
        prior, coding_rank=weight.shape.rank,
        compression=True, stateless=True, offset_heuristic=False)
    string = em.compress(weight / tf.exp(log_step))
    weight_shape = tf.cast(weight_shape, tf.uint16)
    return [string, weight_shape, log_step, em.cdf, em.cdf_offset]

  def decompress_weights(self, string, weight_shape, log_step,
                         cdf, cdf_offset) -> tf.Tensor:
    weight_shape = tf.cast(weight_shape, tf.int32)
    if weight_shape.shape[0] <= 2:
      # Bias or dense kernel.
      em = tfc.ContinuousBatchedEntropyModel(
          prior_shape=log_step.shape, cdf=cdf, cdf_offset=cdf_offset,
          coding_rank=weight_shape.shape[0], compression=True, stateless=True,
          offset_heuristic=False)
      return em.decompress(string, weight_shape) * tf.exp(log_step)
    else:
      # Convolution kernel.
      em = tfc.ContinuousBatchedEntropyModel(
          prior_shape=log_step.shape, cdf=cdf, cdf_offset=cdf_offset,
          coding_rank=weight_shape.shape[0] + 1, compression=True,
          stateless=True, offset_heuristic=False)
      kernel_rdft = em.decompress(string, weight_shape[-2:])
      kernel_rdft *= tf.exp(log_step) * tf.sqrt(
          tf.cast(tf.reduce_prod(weight_shape[:-2]), kernel_rdft.dtype))
      kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1))
      if weight_shape.shape[0] == 3:
        kernel = tf.signal.irfft(kernel_rdft, fft_length=weight_shape[:-2])
        return tf.transpose(kernel, (2, 0, 1))
      else:
        kernel = tf.signal.irfft2d(kernel_rdft, fft_length=weight_shape[:-2])
        return tf.transpose(kernel, (2, 3, 0, 1))

  def compute_entropy(self, *training_weights) -> tf.Tensor:
    if len(training_weights) == 3:
      # Bias or dense kernel.
      weight, log_step, log_scale = training_weights
    else:
      # Convolution kernel.
      weight, _, log_step, log_scale = training_weights
    prior = tfc.NoisyLogistic(loc=0., scale=tf.exp(log_scale))
    em = tfc.ContinuousBatchedEntropyModel(
        prior, coding_rank=weight.shape.rank,
        compression=False, offset_heuristic=False)
    _, bits = em(weight / tf.exp(log_step), training=True)
    return bits

  def get_training_model(self, model: tf.keras.Model) -> tf.keras.Model:
    """Augments a model for training with EPR."""
    # pylint: disable=protected-access
    if (not isinstance(model, tf.keras.Sequential) and
        not model._is_graph_network):
      raise ValueError(
          "`model` must be either a sequential or functional model.")
    # pylint: enable=protected-access

    entropies = []

    # Total number of scalar weights in the original model. Used to bring
    # entropy_penalty into a more standardized range.
    weight_dims = tf.add_n([tf.size(w) for w in model.trainable_weights])

    def create_layer_for_training(layer):
      if not layer.built:
        raise ValueError(
            "Applying EPR currently requires passing in a built model.")
      train_layer = algorithm.create_layer_for_training(layer, algorithm=self)
      train_layer.build(layer.input_shape)
      for name in train_layer.attr_name_map.values():
        entropy = functools.partial(
            self.compute_entropy, *train_layer.training_weights[name])
        entropies.append(entropy)
      return train_layer

    def compute_entropy_loss():
      total_entropy = tf.add_n([e() for e in entropies])
      entropy_penalty = self.entropy_penalty / tf.cast(
          weight_dims, total_entropy.dtype)
      return total_entropy * entropy_penalty

    training_model = tf.keras.models.clone_model(
        model, clone_function=create_layer_for_training)
    training_model.add_loss(compute_entropy_loss)

    # TODO(jballe): It would be great to be able to track the entropy losses
    # combined during training. How to do this?
    # TODO(jballe): Some models might require training log_scale weights with a
    # different optimizer/learning rate. How to do this?
    return training_model

  def compress_model(self, model: tf.keras.Model) -> tf.keras.Model:
    """Compresses a model after training with EPR."""
    # pylint: disable=protected-access
    if (not isinstance(model, tf.keras.Sequential) and
        not model._is_graph_network):
      raise ValueError(
          "`model` must be either a sequential or functional model.")
    # pylint: enable=protected-access

    def create_layer_for_inference(layer):
      return algorithm.create_layer_for_inference(layer, algorithm=self)

    return tf.keras.models.clone_model(
        model, clone_function=create_layer_for_inference)
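The Fourier-domain reparameterization above admits a quick sanity check. The following standalone sketch is not part of the commit: the kernel shape is an arbitrary example and the quantization step is omitted. It mirrors the transform in init_training_weights and its inverse in project_training_weights, and should recover the kernel up to float error.

import tensorflow as tf

# Hypothetical Conv2D kernel: (height, width, in_channels, out_channels).
kernel = tf.random.normal((3, 3, 8, 16))
kernel_shape = tf.shape(kernel)

# Forward, as in init_training_weights: spatial dims last, 2D real FFT,
# real/imaginary parts stacked on a trailing axis, then normalized.
kernel_rdft = tf.signal.rfft2d(tf.transpose(kernel, (2, 3, 0, 1)))
kernel_rdft = tf.stack(
    [tf.math.real(kernel_rdft), tf.math.imag(kernel_rdft)], axis=-1)
kernel_rdft /= tf.sqrt(tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel.dtype))

# Inverse, as in project_training_weights, with the rounding left out; the
# rescaling here cancels the normalization above exactly.
kernel_rdft *= tf.sqrt(tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel.dtype))
complex_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1))
restored = tf.transpose(
    tf.signal.irfft2d(complex_rdft, fft_length=kernel_shape[:-2]), (2, 3, 0, 1))

tf.debugging.assert_near(kernel, restored, atol=1e-5)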

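The compress/decompress pair can likewise be exercised in isolation. Below is a minimal sketch, assuming a toy dense kernel and scalar log_step/log_scale values; the tensorflow_compression calls are the same ones compress_training_weights and decompress_weights make above.

import tensorflow as tf
import tensorflow_compression as tfc

# Toy dense kernel and the scalar parameters EPR would learn for it.
weight = tf.random.normal((64, 10))
log_step = tf.constant(-4.0)
log_scale = tf.constant(2.5)

# Encode, as in compress_training_weights.
prior = tfc.NoisyLogistic(loc=0., scale=tf.exp(log_scale))
encoder = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=weight.shape.rank,
    compression=True, stateless=True, offset_heuristic=False)
string = encoder.compress(weight / tf.exp(log_step))

# Decode, as in decompress_weights: the entropy model is rebuilt from its
# tables (cdf, cdf_offset), so the prior itself never has to be stored.
decoder = tfc.ContinuousBatchedEntropyModel(
    prior_shape=log_step.shape, cdf=encoder.cdf, cdf_offset=encoder.cdf_offset,
    coding_rank=weight.shape.rank, compression=True, stateless=True,
    offset_heuristic=False)
decoded = decoder.decompress(string, tf.shape(weight)) * tf.exp(log_step)

# decoded equals weight quantized to integer multiples of exp(log_step).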
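Finally, judging from get_training_model and compress_model, the intended workflow looks roughly like the sketch below. This is an illustration, not code from the commit: the model architecture, optimizer, and entropy_penalty value are placeholder assumptions.

import tensorflow as tf
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import epr

# A small, already-built model; EPR requires passing in a built model.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, activation="relu", input_shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

# Reparameterize compressible kernels/biases and attach the entropy penalty
# as an auxiliary loss (the penalty value here is an arbitrary example).
compressor = epr.EPR(entropy_penalty=1e-2)
training_model = compressor.get_training_model(model)
training_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# training_model.fit(x_train, y_train, epochs=...)

# After training, swap each layer for one that stores entropy-coded weight
# strings and decodes them at inference time.
compressed_model = compressor.compress_model(training_model)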