|
| 1 | +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Registry responsible for built-in keras classes.""" |
| 16 | + |
| 17 | +import tensorflow as tf |
| 18 | + |
| 19 | +from tensorflow_model_optimization.python.core.quantization.keras import quant_ops |
| 20 | +from tensorflow_model_optimization.python.core.quantization.keras import quantizers |
| 21 | +from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import ( |
| 22 | + default_8bit_quantizers) |
| 23 | + |
| 24 | +layers = tf.keras.layers |
| 25 | + |
| 26 | + |
| 27 | +class _PrunePreserveInfo(object): |
| 28 | + """PrunePreserveInfo.""" |
| 29 | + def __init__(self, weight_attrs, quantize_config_attrs): |
| 30 | + """PrunePreserveInfo. |
| 31 | +
|
| 32 | + Args: |
| 33 | + weight_attrs: list of sparsity preservable weight attributes of layer. |
| 34 | + quantize_config_attrs: list of quantization configuration class name. |
| 35 | + """ |
| 36 | + self.weight_attrs = weight_attrs |
| 37 | + self.quantize_config_attrs = quantize_config_attrs |
| 38 | + |
| 39 | + |
class PrunePreserveQuantizeRegistry(object):
  """Registry of built-in keras layers usable with prune-preserving QAT."""

  # Maps each supported built-in keras layer class to a _PrunePreserveInfo.
  # The first argument lists the layer attributes which hold the kernel
  # weights; the second lists the class names of the quantization
  # configurations under which those weights are sparsity preservable.
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D:
      _PrunePreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense:
      _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # Layers that are supported with prune, but not yet with QAT:
      # layers.Conv1D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv2DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.LocallyConnected1D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # layers.LocallyConnected2D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # DepthwiseConv2D is supported with 8bit QAT, but not with prune:
      # layers.DepthwiseConv2D:
      # _PrunePreserveInfo(['depthwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),

      # SeparableConv still needs to be verified with 8bit QAT:
      # layers.SeparableConv1D:
      # _PrunePreserveInfo(['pointwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),
      # layers.SeparableConv2D:
      # _PrunePreserveInfo(['pointwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),

      # Embedding still needs to be verified with 8bit QAT:
      # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
  }

  def __init__(self):
    """Creates one sparsity-preserving quantizer per supported config name."""
    # Keyed by quantize config class name. Each quantizer is created once
    # here, so every layer handled by this registry instance that uses the
    # same config class receives the same quantizer object.
    self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
        PrunePerserveDefault8BitWeightsQuantizer(),
        'Default8BitConvQuantizeConfig':
        PrunePerserveDefault8BitConvWeightsQuantizer(),
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has no trainable weights.

    Args:
      layer: The layer to check for trainable weights.

    Returns:
      True if `layer.trainable_weights` is empty, False otherwise.
    """
    return len(layer.trainable_weights) == 0

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Layers without trainable weights (e.g., ReLU, Softmax, and
    AveragePooling2D) are always considered supported. Any other layer is
    supported only if its exact class is a key of `_LAYERS_CONFIG_MAP`.

    Args:
      layer: The layer to check for support.

    Returns:
      True/False whether the layer type is supported.
    """
    return (cls._no_trainable_weights(layer) or
            layer.__class__ in cls._LAYERS_CONFIG_MAP)

  @classmethod
  def _weight_names(cls, layer):
    """Returns the attribute names of the layer's preservable weights.

    Args:
      layer: instance of keras layer.

    Returns:
      An empty list for a layer without trainable weights, otherwise the
      `weight_attrs` registered for the layer's class. A layer with
      trainable weights whose class is not registered raises KeyError.
    """
    if cls._no_trainable_weights(layer):
      return []

    layer_info = cls._LAYERS_CONFIG_MAP[layer.__class__]
    return layer_info.weight_attrs

  @classmethod
  def get_sparsity_preservable_weights(cls, layer):
    """Gets the sparsity preservable weights from a keras layer.

    Args:
      layer: instance of keras layer.

    Returns:
      List of sparsity preservable weights, read from the layer by
      attribute name.
    """
    return [getattr(layer, attr_name) for attr_name in cls._weight_names(layer)]

  @classmethod
  def get_suppport_quantize_config_names(cls, layer):
    """Gets the class names of the quantize configs supported for a layer.

    Args:
      layer: instance of keras layer.

    Returns:
      List of supported quantize config class names. Empty for layers
      without trainable weights, which need no quantize config for pqat.
      A layer with trainable weights whose class is not registered raises
      KeyError.
    """
    if cls._no_trainable_weights(layer):
      return []

    layer_info = cls._LAYERS_CONFIG_MAP[layer.__class__]
    return layer_info.quantize_config_attrs

  def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
    """Applies weights sparsity preservation to a quantize config.

    Args:
      layer: The layer to check for support.
      quantize_config: quantization config to check for support. When it is
        supported, its `weight_quantizer` attribute is overwritten in place
        with the matching sparsity-preserving quantizer.

    Returns:
      The same `quantize_config` object. It is returned unmodified when the
      layer has no trainable weights.

    Raises:
      ValueError: if the layer is not supported, or if the class name of
        `quantize_config` is not registered for the layer's class.
    """
    if not self.supports(layer):
      raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')

    # Nothing to preserve when there are no trainable weights.
    if self._no_trainable_weights(layer):
      return quantize_config

    config_name = quantize_config.__class__.__name__
    layer_info = self._LAYERS_CONFIG_MAP[layer.__class__]
    if config_name not in layer_info.quantize_config_attrs:
      raise ValueError('Configuration ' +
                       str(config_name) +
                       ' is not supported for Layer ' +
                       str(layer.__class__) + '.')

    quantize_config.weight_quantizer = self._config_quantizer_map[config_name]
    return quantize_config
| 190 | + |
| 191 | + |
class PrunePerserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Last-value weights quantizer that keeps zero weights at zero."""

  def __init__(self, num_bits, per_axis, symmetric, narrow_range):
    """Forwards the quantization settings to the base quantizer.

    Args:
      num_bits: Number of bits for quantization
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis.
      symmetric: If true, use symmetric quantization limits instead of training
        the minimum and maximum of each quantization range separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized range
        to be [-127, 127] instead of [-128, 127]. This ensures symmetric
        range has 0 as the centre.
    """
    super(PrunePerserveDefaultWeightsQuantizer, self).__init__(
        num_bits=num_bits,
        per_axis=per_axis,
        symmetric=symmetric,
        narrow_range=narrow_range,
    )

  def _build_sparsity_mask(self, name, layer):
    """Builds the sparsity mask from the wrapped layer's weights.

    Args:
      name: Name of the weights attribute on the wrapped layer.
      layer: quantization wrapper; the weights are read from `layer.layer`.

    Returns:
      Dictionary with the single key 'sparsity_mask'.
    """
    kernel = getattr(layer.layer, name)
    # w / w is 1 for every non-zero weight, while divide_no_nan yields 0
    # wherever the weight (the denominator) is 0.
    mask = tf.math.divide_no_nan(kernel, kernel)

    return {'sparsity_mask': mask}

  def build(self, tensor_shape, name, layer):
    """Constructs the mask used to preserve weights sparsity.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: quantization wrapped keras layer.

    Returns:
      Dictionary holding the constructed sparsity mask together with the
      quantization params built by the base class; the dictionary will be
      passed to the __call__ function.
    """
    quantizer_params = self._build_sparsity_mask(name, layer)
    base_params = super(PrunePerserveDefaultWeightsQuantizer, self).build(
        tensor_shape, name, layer)
    quantizer_params.update(base_params)
    return quantizer_params

  def __call__(self, inputs, training, weights, **kwargs):
    """Applies sparsity preserved quantization to the input tensor.

    Args:
      inputs: Input tensor (layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights (params) the quantizer can use to
        quantize the tensor (layer's weights). This contains the weights
        created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.

    Returns:
      The quantized tensor.
    """
    # Multiply by the mask before quantizing, so masked positions are
    # quantized from 0.
    masked_inputs = tf.multiply(inputs, weights['sparsity_mask'])

    return quant_ops.LastValueQuantize(
        masked_inputs,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range,
    )
| 265 | + |
| 266 | + |
class PrunePerserveDefault8BitWeightsQuantizer(
    PrunePerserveDefaultWeightsQuantizer):
  """Sparsity-preserving weights quantizer for default 8bit weights."""

  def __init__(self):
    """Fixes 8 bits, no per-axis, symmetric, narrow-range quantization."""
    super(PrunePerserveDefault8BitWeightsQuantizer, self).__init__(
        num_bits=8,
        per_axis=False,
        symmetric=True,
        narrow_range=True,
    )
| 276 | + |
| 277 | + |
class PrunePerserveDefault8BitConvWeightsQuantizer(
    PrunePerserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer,
):
  """Sparsity-preserving quantizer for default 8bit Conv2D-style weights."""

  def __init__(self):
    """Initializes through Default8BitConvWeightsQuantizer only."""
    # The conv base is called explicitly rather than through super():
    # PrunePerserveDefaultWeightsQuantizer.__init__ requires four
    # quantization arguments, whereas the conv base is called with none.
    conv_quantizer = default_8bit_quantizers.Default8BitConvWeightsQuantizer
    conv_quantizer.__init__(self)

  def build(self, tensor_shape, name, layer):
    """Combines the sparsity mask with the conv quantizer's params.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: quantization wrapped keras layer.

    Returns:
      Dictionary holding 'sparsity_mask' plus whatever
      Default8BitConvWeightsQuantizer.build returns.
    """
    conv_quantizer = default_8bit_quantizers.Default8BitConvWeightsQuantizer
    quantizer_params = PrunePerserveDefaultWeightsQuantizer._build_sparsity_mask(
        self, name, layer)
    conv_params = conv_quantizer.build(self, tensor_shape, name, layer)
    quantizer_params.update(conv_params)
    return quantizer_params
0 commit comments