# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| 15 | +"""Entropy Penalized Reparameterization algorithm. |
| 16 | +
|
| 17 | +This is an implementation of the method described in: |
| 18 | +> "Scalable Model Compression by Entropy Penalized Reparameterization"<br /> |
| 19 | +> D. Oktay, J. Ballé, S. Singh, A. Shrivastava<br /> |
| 20 | +> https://arxiv.org/abs/1906.06624 |
| 21 | +""" |

import functools
from typing import List
import tensorflow as tf
import tensorflow_compression as tfc
from tensorflow_model_optimization.python.core.common.keras.compression import algorithm


class EPR(algorithm.WeightCompressor):
  """Defines how to apply the EPR algorithm."""

  def __init__(self, entropy_penalty):
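    """Initializes the algorithm.

    Args:
      entropy_penalty: Float. Weight of the entropy penalty that
        `get_training_model` adds to the training loss, normalized by the
        total number of weights in the model.
    """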
    self.entropy_penalty = entropy_penalty

  def get_compressible_weights(self, original_layer):
    if isinstance(
        original_layer,
        (tf.keras.layers.Dense, tf.keras.layers.Conv1D, tf.keras.layers.Conv2D),
    ):
      if original_layer.use_bias:
        return [original_layer.kernel, original_layer.bias]
      else:
        return [original_layer.kernel]
    return []

  def init_training_weights(self, pretrained_weight: tf.Tensor):
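    """Creates training weights for one compressible weight tensor.

    Biases and dense kernels are stored as-is; convolution kernels are stored
    as their real-input DFT coefficients. A quantization step size and a prior
    scale (both stored as logarithms) are created alongside each tensor.
    """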
    shape = pretrained_weight.shape
    dtype = pretrained_weight.dtype
    weight_name = "bias" if shape.rank == 1 else "kernel"

    if 1 <= shape.rank <= 2:
      # Bias or dense kernel.
      prior_shape = []
      self.add_training_weight(
          name=weight_name,
          shape=pretrained_weight.shape,
          dtype=pretrained_weight.dtype,
          initializer=tf.keras.initializers.Constant(pretrained_weight))
    elif 3 <= shape.rank <= 4:
      # Convolution kernel.
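      # The spatial kernel is reparameterized by its real-input DFT over the
      # spatial axes, stored as stacked real and imaginary parts. Dividing by
      # sqrt(number of spatial elements) makes the transform roughly
      # norm-preserving, so the same step size and prior initialization can be
      # used across kernel sizes.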
      kernel_shape = tf.shape(pretrained_weight)
      if shape.rank == 3:
        kernel_rdft = tf.signal.rfft(
            tf.transpose(pretrained_weight, (1, 2, 0)))
      else:
        kernel_rdft = tf.signal.rfft2d(
            tf.transpose(pretrained_weight, (2, 3, 0, 1)))
      kernel_rdft = tf.stack(
          [tf.math.real(kernel_rdft), tf.math.imag(kernel_rdft)], axis=-1)
      prior_shape = tf.shape(kernel_rdft)[2:]
      kernel_rdft /= tf.sqrt(tf.cast(tf.reduce_prod(kernel_shape[:-2]), dtype))
      self.add_training_weight(
          name="kernel_rdft",
          shape=kernel_rdft.shape,
          dtype=kernel_rdft.dtype,
          initializer=tf.keras.initializers.Constant(kernel_rdft))
      self.add_training_weight(
          name="kernel_shape",
          shape=kernel_shape.shape,
          dtype=kernel_shape.dtype,
          # TODO(jballe): If False, breaks optimize.create_layer_for_training().
          # If True, throws warnings that int tensors have no gradient.
          # trainable=False,
          initializer=tf.keras.initializers.Constant(kernel_shape))
    else:
      raise ValueError(
          "Expected bias or kernel tensor with rank between 1 and 4, received "
          f"shape {shape}.")

    # Logarithm of the quantization step size (initialized to -4, i.e. a step
    # of exp(-4) ~ 0.018).
    log_step = tf.fill(prior_shape, tf.constant(-4, dtype=dtype))
    self.add_training_weight(
        name=f"{weight_name}_log_step",
        shape=log_step.shape,
        dtype=log_step.dtype,
        initializer=tf.keras.initializers.Constant(log_step))

    # Logarithm of the scale of the prior (initialized to 2.5, i.e. a scale of
    # exp(2.5) ~ 12).
    log_scale = tf.fill(prior_shape, tf.constant(2.5, dtype=dtype))
    self.add_training_weight(
        name=f"{weight_name}_log_scale",
        shape=log_scale.shape,
        dtype=log_scale.dtype,
        initializer=tf.keras.initializers.Constant(log_scale))
  def project_training_weights(self, *training_weights) -> tf.Tensor:
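    """Maps training weights back to a weight tensor for the layer.

    The weight is quantized to the learned step size with straight-through
    rounding; convolution kernels are additionally transformed from the DFT
    domain back to a spatial kernel.
    """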
    if len(training_weights) == 3:
      # Bias or dense kernel.
      weight, log_step, _ = training_weights
      step = tf.exp(log_step)
      return tfc.round_st(weight / step) * step
    else:
      # Convolution kernel.
      kernel_rdft, kernel_shape, log_step, _ = training_weights
      step = tf.exp(log_step)
      kernel_rdft = tfc.round_st(kernel_rdft / step)
      kernel_rdft *= step * tf.sqrt(
          tf.cast(tf.reduce_prod(kernel_shape[:-2]), kernel_rdft.dtype))
      kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1))
      if kernel_rdft.shape.rank == 3:
        kernel = tf.signal.irfft(kernel_rdft, fft_length=kernel_shape[:-2])
        return tf.transpose(kernel, (2, 0, 1))
      else:
        kernel = tf.signal.irfft2d(kernel_rdft, fft_length=kernel_shape[:-2])
        return tf.transpose(kernel, (2, 3, 0, 1))

  def compress_training_weights(
      self, *training_weights: tf.Tensor) -> List[tf.Tensor]:
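    """Entropy codes a training weight into a compact representation.

    Returns the compressed bit string together with the weight shape, the log
    step size, and the entropy model's CDF tables needed for decompression.
    """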
    if len(training_weights) == 3:
      # Bias or dense kernel.
      weight, log_step, log_scale = training_weights
      weight_shape = tf.shape(weight)
    else:
      # Convolution kernel.
      weight, weight_shape, log_step, log_scale = training_weights
    prior = tfc.NoisyLogistic(loc=0., scale=tf.exp(log_scale))
    em = tfc.ContinuousBatchedEntropyModel(
        prior, coding_rank=weight.shape.rank,
        compression=True, stateless=True, offset_heuristic=False)
    string = em.compress(weight / tf.exp(log_step))
    weight_shape = tf.cast(weight_shape, tf.uint16)
    return [string, weight_shape, log_step, em.cdf, em.cdf_offset]

  def decompress_weights(self, string, weight_shape, log_step,
                         cdf, cdf_offset) -> tf.Tensor:
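    """Reconstructs a weight tensor from its compressed representation.

    Inverts `compress_training_weights`: decodes the bit string using the
    stored CDF tables, rescales by the step size, and, for convolution
    kernels, applies the inverse DFT.
    """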
    weight_shape = tf.cast(weight_shape, tf.int32)
    if weight_shape.shape[0] <= 2:
      # Bias or dense kernel.
      em = tfc.ContinuousBatchedEntropyModel(
          prior_shape=log_step.shape, cdf=cdf, cdf_offset=cdf_offset,
          coding_rank=weight_shape.shape[0], compression=True, stateless=True,
          offset_heuristic=False)
      return em.decompress(string, weight_shape) * tf.exp(log_step)
    else:
      # Convolution kernel.
      em = tfc.ContinuousBatchedEntropyModel(
          prior_shape=log_step.shape, cdf=cdf, cdf_offset=cdf_offset,
          coding_rank=weight_shape.shape[0] + 1, compression=True,
          stateless=True, offset_heuristic=False)
      kernel_rdft = em.decompress(string, weight_shape[-2:])
      kernel_rdft *= tf.exp(log_step) * tf.sqrt(
          tf.cast(tf.reduce_prod(weight_shape[:-2]), kernel_rdft.dtype))
      kernel_rdft = tf.dtypes.complex(*tf.unstack(kernel_rdft, axis=-1))
      if weight_shape.shape[0] == 3:
        kernel = tf.signal.irfft(kernel_rdft, fft_length=weight_shape[:-2])
        return tf.transpose(kernel, (2, 0, 1))
      else:
        kernel = tf.signal.irfft2d(kernel_rdft, fft_length=weight_shape[:-2])
        return tf.transpose(kernel, (2, 3, 0, 1))

  def compute_entropy(self, *training_weights) -> tf.Tensor:
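    """Returns an estimate of the number of bits needed to encode the weight.

    Uses the additive-noise relaxation of the entropy model, so the estimate
    is differentiable and can be added to the training loss as a penalty.
    """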
    if len(training_weights) == 3:
      # Bias or dense kernel.
      weight, log_step, log_scale = training_weights
    else:
      # Convolution kernel.
      weight, _, log_step, log_scale = training_weights
    prior = tfc.NoisyLogistic(loc=0., scale=tf.exp(log_scale))
    em = tfc.ContinuousBatchedEntropyModel(
        prior, coding_rank=weight.shape.rank,
        compression=False, offset_heuristic=False)
    _, bits = em(weight / tf.exp(log_step), training=True)
    return bits

  def get_training_model(self, model: tf.keras.Model) -> tf.keras.Model:
    """Augments a model for training with EPR."""
    # pylint: disable=protected-access
    if (not isinstance(model, tf.keras.Sequential) and
        not model._is_graph_network):
      raise ValueError(
          "`model` must be either a sequential or functional model.")
    # pylint: enable=protected-access

    entropies = []

    # Total number of scalar weights in the original model. Used to bring
    # entropy_penalty into a more standardized range.
    weight_dims = tf.add_n([tf.size(w) for w in model.trainable_weights])

    def create_layer_for_training(layer):
      if not layer.built:
        raise ValueError(
            "Applying EPR currently requires passing in a built model.")
      train_layer = algorithm.create_layer_for_training(layer, algorithm=self)
      train_layer.build(layer.input_shape)
      for name in train_layer.attr_name_map.values():
        entropy = functools.partial(
            self.compute_entropy, *train_layer.training_weights[name])
        entropies.append(entropy)
      return train_layer

    def compute_entropy_loss():
      total_entropy = tf.add_n([e() for e in entropies])
      entropy_penalty = self.entropy_penalty / tf.cast(
          weight_dims, total_entropy.dtype)
      return total_entropy * entropy_penalty

    training_model = tf.keras.models.clone_model(
        model, clone_function=create_layer_for_training)
    training_model.add_loss(compute_entropy_loss)

    # TODO(jballe): It would be great to be able to track the entropy losses
    # combined during training. How to do this?
    # TODO(jballe): Some models might require training log_scale weights with a
    # different optimizer/learning rate. How to do this?
    return training_model

  def compress_model(self, model: tf.keras.Model) -> tf.keras.Model:
    """Compresses a model after training with EPR."""
    # pylint: disable=protected-access
    if (not isinstance(model, tf.keras.Sequential) and
        not model._is_graph_network):
      raise ValueError(
          "`model` must be either a sequential or functional model.")
    # pylint: enable=protected-access

    def create_layer_for_inference(layer):
      return algorithm.create_layer_for_inference(layer, algorithm=self)

    return tf.keras.models.clone_model(
        model, clone_function=create_layer_for_inference)