Skip to content

Commit e39363d

Browse files
committed
Enable prune preserve quantization aware training (pqat)
Implemented prune preserve quant-aware training. Implemented pqat registry for supported keras layers. Implemented pqat weights quantizers. Example with simple cnn model on mnist. Added unit tests around the new changes. experimental API with tfmot.quantization.keras.quantize_apply() Change-Id: I3646c7d219399ac3191ae34b18a18db3801a8825 simplify PQAT registry, remove comments regarding private API. Change-Id: I71f5aa5ef8f3b2254e646a3022656a56514ed6e4
1 parent 3a0c22d commit e39363d

File tree

9 files changed

+660
-0
lines changed

9 files changed

+660
-0
lines changed

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ py_library(
1717
"//tensorflow_model_optimization/python/core/quantization/keras/graph_transformations",
1818
"//tensorflow_model_optimization/python/core/quantization/keras/layers",
1919
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit",
20+
"//tensorflow_model_optimization/python/core/quantization/keras/prune_preserve",
2021
],
2122
)
2223

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Build rules for the prune-preserve quantization aware training (PQAT)
# package.

package(default_visibility = [
    "//tensorflow_model_optimization:__subpackages__",
])

licenses(["notice"])  # Apache 2.0

# Package marker target; it only carries the (license-only) __init__.py.
py_library(
    name = "prune_preserve",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY3",
    deps = [],
)

# Registry and weight quantizers used for prune-preserve QAT.
py_library(
    name = "prune_preserve_quantize_registry",
    srcs = [
        "prune_preserve_quantize_registry.py",
    ],
    srcs_version = "PY3",
    deps = [
        # tensorflow dep1,
        # prune_preserve_quantize_registry.py imports `quant_ops` and
        # `quantizers` directly, so they are declared here instead of being
        # picked up transitively.
        # NOTE(review): target names are assumed to match the module names;
        # confirm against quantization/keras/BUILD.
        "//tensorflow_model_optimization/python/core/quantization/keras:quant_ops",
        "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
        "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantizers",
    ],
)

py_test(
    name = "prune_preserve_quantize_registry_test",
    srcs = [
        "prune_preserve_quantize_registry_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":prune_preserve_quantize_registry",
        # tensorflow dep1,
        "//tensorflow_model_optimization/python/core/sparsity/keras:prune_registry",
        "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
    ],
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Registry responsible for built-in keras classes."""
16+
17+
import tensorflow as tf
18+
19+
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
20+
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
21+
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
22+
default_8bit_quantizers)
23+
24+
layers = tf.keras.layers
25+
26+
27+
class _PrunePreserveInfo:
  """Record describing how one layer type takes part in prune-preserve QAT."""

  def __init__(self, weight_attrs, quantize_config_attrs):
    """Stores both lists on the instance exactly as they were passed in.

    Args:
      weight_attrs: list of names of the layer attributes holding the weights
        whose sparsity can be preserved.
      quantize_config_attrs: list of class names of the quantization
        configurations supported for the layer.
    """
    self.quantize_config_attrs = quantize_config_attrs
    self.weight_attrs = weight_attrs
38+
39+
40+
class PrunePreserveQuantizeRegistry(object):
  """Registry of built-in keras layers supported by prune-preserve QAT.

  The registry answers which layers are supported, which of their weights are
  sparsity preservable, and swaps the weight quantizer of a supported
  quantization configuration for a sparsity preserving one.
  """

  # Keys are built-in keras layer classes. Each value records (1) the names
  # of the layer attributes that hold the kernel weights and (2) the class
  # names of the quantization configurations for the layer. Together they
  # decide for which layer/configuration pairs the weights are sparsity
  # preservable.
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D:
      _PrunePreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense:
      _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # Layers that are supported by pruning, but not yet by QAT:
      # layers.Conv1D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv2DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3D:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.Conv3DTranspose:
      # _PrunePreserveInfo(['kernel'], []),
      # layers.LocallyConnected1D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # layers.LocallyConnected2D:
      # _PrunePreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # DepthwiseConv2D is supported by 8bit QAT, but not by pruning:
      # layers.DepthwiseConv2D:
      # _PrunePreserveInfo(['depthwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),

      # SeparableConv support needs to be verified with 8bit QAT:
      # layers.SeparableConv1D:
      # _PrunePreserveInfo(['pointwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),
      # layers.SeparableConv2D:
      # _PrunePreserveInfo(['pointwise_kernel'],
      #                    ['Default8BitConvQuantizeConfig']),

      # Embedding support needs to be verified with 8bit QAT:
      # layers.Embedding: _PrunePreserveInfo(['embeddings'], []),
  }

  def __init__(self):
    """Creates the quantizers used to replace the default weight quantizers."""
    # Maps the class name of a quantization configuration to the sparsity
    # preserving weights quantizer installed on it. One quantizer instance per
    # configuration name is shared by every layer using this registry.
    self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
        PrunePerserveDefault8BitWeightsQuantizer(),
        'Default8BitConvQuantizeConfig':
        PrunePerserveDefault8BitConvWeightsQuantizer(),
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has no trainable weights.

    Args:
      layer: The layer to check for trainable weights.

    Returns:
      True if `layer.trainable_weights` is empty, False otherwise.
    """
    return not layer.trainable_weights

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Args:
      layer: The layer to check for support.

    Returns:
      True/False whether the layer type is supported.
    """
    # Layers without trainable weights are considered supported,
    # e.g., ReLU, Softmax, and AveragePooling2D.
    if cls._no_trainable_weights(layer):
      return True

    # Exact class match: subclasses of registered layers are not looked up.
    return layer.__class__ in cls._LAYERS_CONFIG_MAP

  @classmethod
  def _weight_names(cls, layer):
    """Returns the names of the sparsity preservable weights of the layer.

    Args:
      layer: instance of keras layer.

    Returns:
      Empty list for a layer without trainable weights, otherwise the
      registered weight attribute names of the layer class.

    Raises:
      KeyError: if the layer has trainable weights and its class is not
        registered.
    """
    if cls._no_trainable_weights(layer):
      return []

    return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs

  @classmethod
  def get_sparsity_preservable_weights(cls, layer):
    """Gets the sparsity preservable weights from a keras layer.

    Args:
      layer: instance of keras layer.

    Returns:
      List of sparsity preservable weights, read from the layer attributes
      named by the registry.
    """
    return [getattr(layer, weight) for weight in cls._weight_names(layer)]

  @classmethod
  def get_suppport_quantize_config_names(cls, layer):
    """Gets the class names of the supported quantize configs for a layer.

    Args:
      layer: instance of keras layer.

    Returns:
      List of supported quantize config class names.

    Raises:
      KeyError: if the layer has trainable weights and its class is not
        registered.
    """
    # Layers without trainable weights don't need a quantize_config for pqat.
    if cls._no_trainable_weights(layer):
      return []

    return cls._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs

  def apply_sparsity_preserve_quantize_config(self, layer, quantize_config):
    """Applies weights sparsity preservation to a quantization config.

    The given `quantize_config` is modified in place: its `weight_quantizer`
    attribute is replaced by the sparsity preserving quantizer registered for
    the config's class name.

    Args:
      layer: The layer to check for support.
      quantize_config: quantization config to check for support; sparsity
        preservation is applied to the pruned weights it quantizes.

    Returns:
      The same quantize_config, with a sparsity preserving weight_quantizer
      added. It is returned untouched for layers without trainable weights.

    Raises:
      ValueError: if the layer is not supported, or the class of
        `quantize_config` is not supported for the layer.
    """
    if not self.supports(layer):
      raise ValueError('Layer {} is not supported.'.format(layer.__class__))

    # Nothing to preserve when the layer has no trainable weights.
    if self._no_trainable_weights(layer):
      return quantize_config

    config_name = quantize_config.__class__.__name__
    supported_names = (
        self._LAYERS_CONFIG_MAP[layer.__class__].quantize_config_attrs)
    if config_name not in supported_names:
      raise ValueError(
          'Configuration {} is not supported for Layer {}.'.format(
              config_name, layer.__class__))

    quantize_config.weight_quantizer = self._config_quantizer_map[config_name]
    return quantize_config
190+
191+
192+
class PrunePerserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Last-value weights quantizer that masks the weights to keep sparsity."""

  def __init__(self, num_bits, per_axis, symmetric, narrow_range):
    """Forwards the quantization settings to the base quantizer.

    Args:
      num_bits: Number of bits for quantization.
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis.
      symmetric: If true, use symmetric quantization limits instead of
        training the minimum and maximum of each quantization range
        separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized range
        to be [-127, 127] instead of [-128, 127]. This ensures symmetric
        range has 0 as the centre.
    """
    super().__init__(
        num_bits=num_bits,
        per_axis=per_axis,
        symmetric=symmetric,
        narrow_range=narrow_range)

  def _build_sparsity_mask(self, name, layer):
    """Derives the sparsity mask from the wrapped layer's current weights.

    Args:
      name: Name of weights in layer.
      layer: quantization wrapped keras layer; the weights are read from its
        `layer` attribute.

    Returns:
      Dictionary with the single key 'sparsity_mask'.
    """
    kernel = getattr(layer.layer, name)
    # Dividing the weights by themselves gives 1 for non-zero entries, and
    # divide_no_nan gives 0 where the denominator (a pruned weight) is 0.
    return {'sparsity_mask': tf.math.divide_no_nan(kernel, kernel)}

  def build(self, tensor_shape, name, layer):
    """Constructs the mask that preserves weights sparsity.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: quantization wrapped keras layer.

    Returns:
      Dictionary holding the constructed sparsity mask together with the
      quantization params built by the base quantizer; the dictionary will be
      passed to the __call__ function.
    """
    # The mask is built first, then the base quantizer's entries are merged in.
    params = self._build_sparsity_mask(name, layer)
    quantizer_params = super().build(tensor_shape, name, layer)
    params.update(quantizer_params)
    return params

  def __call__(self, inputs, training, weights, **kwargs):
    """Applies sparsity preserved quantization to the input tensor.

    Args:
      inputs: Input tensor (layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights (params) the quantizer can use to
        quantize the tensor (layer's weights). This contains the weights
        created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.

    Returns:
      Quantized tensor.
    """
    # Multiply by the mask before quantizing so masked entries are fed to the
    # quantizer as zeros.
    masked_inputs = tf.multiply(inputs, weights['sparsity_mask'])
    return quant_ops.LastValueQuantize(
        masked_inputs,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range)
265+
266+
267+
class PrunePerserveDefault8BitWeightsQuantizer(
    PrunePerserveDefaultWeightsQuantizer):
  """Sparsity preserving weights quantizer with the default 8bit settings."""

  def __init__(self):
    """Configures 8 bit, per-tensor, symmetric, narrow-range quantization."""
    settings = dict(
        num_bits=8,
        per_axis=False,
        symmetric=True,
        narrow_range=True)
    super().__init__(**settings)
276+
277+
278+
class PrunePerserveDefault8BitConvWeightsQuantizer(
    PrunePerserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer,
):
  """Sparsity preserving quantizer for default 8bit conv weights.

  Intended for Conv2D/DepthwiseConv2D weights. Quantization is inherited from
  PrunePerserveDefaultWeightsQuantizer.__call__, while the settings and the
  quantization variables come from Default8BitConvWeightsQuantizer.
  """

  def __init__(self):
    """Initializes with the settings of Default8BitConvWeightsQuantizer."""
    # Called on the conv quantizer class explicitly: the first base class in
    # the MRO has an __init__ that requires four settings arguments.
    conv_quantizer = default_8bit_quantizers.Default8BitConvWeightsQuantizer
    conv_quantizer.__init__(self)

  def build(self, tensor_shape, name, layer):
    """Builds the sparsity mask plus the conv quantizer's variables.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: quantization wrapped keras layer.

    Returns:
      Dictionary with the 'sparsity_mask' entry and every entry returned by
      Default8BitConvWeightsQuantizer.build.
    """
    # Mask first, then the conv quantizer's variables are merged in.
    params = PrunePerserveDefaultWeightsQuantizer._build_sparsity_mask(
        self, name, layer)
    conv_params = (
        default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
            self, tensor_shape, name, layer))
    params.update(conv_params)
    return params

0 commit comments

Comments
 (0)