|
| 1 | +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Wrapper which applies quantization operations over underlying layer. |
| 16 | +
|
| 17 | + `QuantizeWrapper` is responsible for modifying the construction of the |
| 18 | + underlying layer to ensure proper quantization operations are placed in the |
| 19 | + graph. |
| 20 | +
|
| 21 | + These operations ensure proper introduction of inference time losses during |
| 22 | + training. |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import absolute_import |
| 26 | +from __future__ import division |
| 27 | +from __future__ import print_function |
| 28 | + |
| 29 | +from tensorflow.python.framework import dtypes |
| 30 | +from tensorflow.python.keras import backend as K |
| 31 | +from tensorflow.python.keras import initializers |
| 32 | +from tensorflow.python.keras.layers.wrappers import Wrapper |
| 33 | +from tensorflow.python.keras.utils import tf_utils |
| 34 | + |
| 35 | +from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation |
| 36 | +from tensorflow_model_optimization.python.core.quantization.keras import quantize_provider as quantize_provider_mod |
| 37 | + |
| 38 | + |
| 39 | +class QuantizeWrapper(Wrapper): |
| 40 | + """Quantizes the weights and activations of the keras layer it wraps.""" |
| 41 | + |
| 42 | + def __init__(self, layer, quantize_provider, **kwargs): |
| 43 | + """Create a quantize emulate wrapper for a keras layer. |
| 44 | +
|
| 45 | + Args: |
| 46 | + layer: The keras layer to be quantized. |
| 47 | + quantize_provider: `QuantizeProvider` to quantize layer. |
| 48 | + **kwargs: Additional keyword arguments to be passed to the keras layer. |
| 49 | + """ |
| 50 | + |
| 51 | + if quantize_provider is None: |
| 52 | + raise ValueError('quantize_provider cannot be None. It is needed to ' |
| 53 | + 'quantize a layer.') |
| 54 | + |
| 55 | + super(QuantizeWrapper, self).__init__(layer, **kwargs) |
| 56 | + self.quantize_provider = quantize_provider |
| 57 | + |
| 58 | + # Ensures cloning of already built layer works. |
| 59 | + if (not hasattr(self, '_batch_input_shape') and |
| 60 | + hasattr(layer, '_batch_input_shape')): |
| 61 | + self._batch_input_shape = self.layer._batch_input_shape # pylint: disable=protected-access |
| 62 | + self._track_trackable(layer, name='layer') |
| 63 | + |
| 64 | + @staticmethod |
| 65 | + def _weight_name(name): |
| 66 | + """Extracts the weight name from the full TensorFlow variable name. |
| 67 | +
|
| 68 | + For example, returns 'kernel' for 'dense_2/kernel:0'. |
| 69 | +
|
| 70 | + Args: |
| 71 | + name: TensorFlow variable name. |
| 72 | +
|
| 73 | + Returns: |
| 74 | + Extracted weight name. |
| 75 | + """ |
| 76 | + return name.split(':')[0].split('/')[-1] |
| 77 | + |
| 78 | + def _add_range_weights(self, name): |
| 79 | + min_weight = self.add_weight( |
| 80 | + name + '_min', initializer=initializers.Constant(-6.0), trainable=False) |
| 81 | + max_weight = self.add_weight( |
| 82 | + name + '_max', initializer=initializers.Constant(6.0), trainable=False) |
| 83 | + |
| 84 | + return min_weight, max_weight |
| 85 | + |
| 86 | + def build(self, input_shape): |
| 87 | + super(QuantizeWrapper, self).build(input_shape) |
| 88 | + |
| 89 | + self.optimizer_step = self.add_weight( |
| 90 | + 'optimizer_step', |
| 91 | + initializer=initializers.Constant(-1), |
| 92 | + dtype=dtypes.int32, |
| 93 | + trainable=False) |
| 94 | + |
| 95 | + self._weight_vars = [] |
| 96 | + for weight, quantizer in \ |
| 97 | + self.quantize_provider.get_weights_and_quantizers(self.layer): |
| 98 | + min_var, max_var = self._add_range_weights(self._weight_name(weight.name)) |
| 99 | + |
| 100 | + self._weight_vars.append((weight, quantizer, min_var, max_var)) |
| 101 | + # Needed to ensure unquantized weights get trained as part of the wrapper. |
| 102 | + self._trainable_weights.append(weight) |
| 103 | + |
| 104 | + self._quantize_activations = [] |
| 105 | + for activation, quantizer in \ |
| 106 | + self.quantize_provider.get_activations_and_quantizers(self.layer): |
| 107 | + quantize_activation = quantize_aware_activation.QuantizeAwareActivation( |
| 108 | + activation, quantizer, self.optimizer_step, self) |
| 109 | + |
| 110 | + self._quantize_activations.append(quantize_activation) |
| 111 | + |
| 112 | + def compute_output_shape(self, input_shape): |
| 113 | + return self.layer.compute_output_shape(self.layer.input_shape) |
| 114 | + |
| 115 | + def _dict_vars(self, min_var, max_var): |
| 116 | + return {'min_var': min_var, 'max_var': max_var} |
| 117 | + |
| 118 | + def call(self, inputs, training=None): |
| 119 | + if training is None: |
| 120 | + training = K.learning_phase() |
| 121 | + |
| 122 | + # Quantize all weights, and replace them in the underlying layer. |
| 123 | + |
| 124 | + quantized_weights = [] |
| 125 | + for unquantized_weight, quantizer, min_var, max_var in self._weight_vars: |
| 126 | + |
| 127 | + def make_quantizer_fn(training): |
| 128 | + """Use currying to return True/False specialized fns to the cond.""" |
| 129 | + |
| 130 | + def quantizer_fn(unquantized_weight=unquantized_weight, |
| 131 | + quantizer=quantizer, |
| 132 | + min_var=min_var, |
| 133 | + max_var=max_var): |
| 134 | + return quantizer(unquantized_weight, self.optimizer_step, training, |
| 135 | + **self._dict_vars(min_var, max_var)) |
| 136 | + |
| 137 | + return quantizer_fn |
| 138 | + |
| 139 | + quantized_weight = tf_utils.smart_cond( |
| 140 | + training, make_quantizer_fn(True), make_quantizer_fn(False)) |
| 141 | + quantized_weights.append(quantized_weight) |
| 142 | + |
| 143 | + self.quantize_provider.set_quantize_weights(self.layer, quantized_weights) |
| 144 | + |
| 145 | + # Replace all activations with `QuantizeAwareActivation`s which can |
| 146 | + # quantize activation tensors during graph construction. |
| 147 | + |
| 148 | + for quantize_activation in self._quantize_activations: |
| 149 | + quantize_activation.training = training |
| 150 | + |
| 151 | + self.quantize_provider.set_quantize_activations( |
| 152 | + self.layer, self._quantize_activations) |
| 153 | + |
| 154 | + return self.layer.call(inputs) |
| 155 | + |
| 156 | + def get_config(self): |
| 157 | + base_config = super(QuantizeWrapper, self).get_config() |
| 158 | + config = {'quantize_provider': self.quantize_provider} |
| 159 | + return dict(list(base_config.items()) + list(config.items())) |
| 160 | + |
| 161 | + @classmethod |
| 162 | + def from_config(cls, config): |
| 163 | + config = config.copy() |
| 164 | + |
| 165 | + quantize_provider = config.pop('quantize_provider') |
| 166 | + from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object # pylint: disable=g-import-not-at-top |
| 167 | + # TODO(pulkitb): Add all known `QuantizeProvider`s to custom_objects |
| 168 | + custom_objects = { |
| 169 | + 'QuantizeProvider': quantize_provider_mod.QuantizeProvider |
| 170 | + } |
| 171 | + config['quantize_provider'] = deserialize_keras_object( |
| 172 | + quantize_provider, |
| 173 | + module_objects=globals(), |
| 174 | + custom_objects=custom_objects) |
| 175 | + |
| 176 | + from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top |
| 177 | + layer = deserialize_layer(config.pop('layer')) |
| 178 | + config['layer'] = layer |
| 179 | + |
| 180 | + return cls(**config) |
| 181 | + |
| 182 | + @property |
| 183 | + def trainable(self): |
| 184 | + return self.layer.trainable |
| 185 | + |
| 186 | + @trainable.setter |
| 187 | + def trainable(self, value): |
| 188 | + self.layer.trainable = value |
| 189 | + |
| 190 | + @property |
| 191 | + def trainable_weights(self): |
| 192 | + return self.layer.trainable_weights + self._trainable_weights |
| 193 | + |
| 194 | + @property |
| 195 | + def non_trainable_weights(self): |
| 196 | + return self.layer.non_trainable_weights + self._non_trainable_weights |
| 197 | + |
| 198 | + @property |
| 199 | + def updates(self): |
| 200 | + return self.layer.updates + self._updates |
| 201 | + |
| 202 | + @property |
| 203 | + def losses(self): |
| 204 | + return self.layer.losses + self._losses |
0 commit comments