Commit 031ef21

nutsiepully authored and tensorflower-gardener committed
quantize_apply function to implement quantization.
This adds support for applying QuantizeAwareActivation to annotated layers. Also adds support for implementing quantization by cloning the model with its existing annotations.

PiperOrigin-RevId: 255680020
1 parent d85b2ab commit 031ef21

File tree

6 files changed: +279 -9 lines changed

tensorflow_model_optimization/python/core/quantization/keras/BUILD
tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py
tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py
tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate.py
tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_test.py
tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_wrapper.py

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,7 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         ":quantize_annotate",
+        ":quantize_aware_activation",
         ":quantize_emulate_wrapper",
         # tensorflow dep1,
         # python/keras tensorflow dep2,
@@ -174,6 +175,7 @@ py_test(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        ":quantize_aware_activation",
         ":quantize_emulate",
         ":quantize_emulate_wrapper",
         # absl/testing:parameterized dep1,

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

Lines changed: 6 additions & 3 deletions
@@ -61,13 +61,16 @@ def __init__(self,
   def call(self, inputs, training=None):
     return self.layer.call(inputs)
 
-  def get_config(self):
-    base_config = super(QuantizeAnnotate, self).get_config()
-    config = {
+  def get_quantize_params(self):
+    return {
         'num_bits': self._num_bits,
         'symmetric': self._symmetric,
         'narrow_range': self._narrow_range
     }
+
+  def get_config(self):
+    base_config = super(QuantizeAnnotate, self).get_config()
+    config = self.get_quantize_params()
     return dict(list(base_config.items()) + list(config.items()))
 
   @classmethod
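
With this split, `get_config` and the new `get_quantize_params` share one source of truth for the quantization arguments. The following is a minimal sketch of the intended contract, not part of this commit; the constructor keyword names are inferred from the tests added below.

# Hypothetical usage sketch; keyword names inferred from quantize_emulate_test.py.
from tensorflow.python import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate

annotated = quantize_annotate.QuantizeAnnotate(
    keras.layers.Dense(10), num_bits=8, symmetric=True, narrow_range=True)

params = annotated.get_quantize_params()
# params == {'num_bits': 8, 'symmetric': True, 'narrow_range': True}

# get_config() should still carry the same keys, so serialization is unchanged.
config = annotated.get_config()
assert all(config[key] == value for key, value in params.items())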

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py

Lines changed: 6 additions & 0 deletions
@@ -137,6 +137,12 @@ def call(self, inputs, training=None):
 
     return x
 
+  def get_quantize_params(self):
+    return {
+        'num_bits': self.num_bits,
+        'symmetric': self.symmetric,
+    }
+
   def compute_output_shape(self, input_shape):
     return input_shape
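
The new accessor mirrors the one on `QuantizeAnnotate`, minus `narrow_range`, which is not relevant to quantizing activations. A purely illustrative sketch of how `quantize_apply` (in quantize_emulate.py below) constructs this class:

# Illustrative sketch; mirrors the _quantize_activation call added in quantize_emulate.py.
from tensorflow.python import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation

# Wrap a 'relu' activation that lives inside a Conv2D layer.
quant_relu = quantize_aware_activation.QuantizeAwareActivation(
    'relu', keras.layers.Conv2D, num_bits=8, symmetric=True)

assert quant_relu.get_quantize_params() == {'num_bits': 8, 'symmetric': True}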

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate.py

Lines changed: 103 additions & 0 deletions
@@ -17,6 +17,7 @@
 from tensorflow.python import keras
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quant_annotate
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras.quantize_emulate_wrapper import QuantizeEmulateWrapper
 
 
@@ -132,3 +133,105 @@ def _add_quant_wrapper(layer):
   elif isinstance(to_quantize, keras.layers.Layer):
     quant_params.update(**kwargs)
     return quant_annotate.QuantizeAnnotate(to_quantize, **quant_params)
+
+
+def quantize_apply(model):
+  """Apply quantization operations to a keras model.
+
+  This function takes a keras model which has been annotated with
+  `quantize_annotate` and constructs a new keras model in which each of the
+  annotated layers has been quantized. The quantization process introduces
+  new quantization ops in the TensorFlow graph to appropriately emulate
+  quantization loss.
+
+  Note that to exactly emulate quantization loss, certain graph/model
+  transformations may be applied. This is required since the actual quantized
+  kernel implementations may apply similar transformations.
+
+  Args:
+    model: A keras Sequential or Functional model which has been annotated
+      with `quantize_annotate`.
+
+  Returns:
+    Returns a new cloned keras model in which the annotated layers have been
+    quantized. All the existing layers are cloned.
+  """
+
+  if not isinstance(model, keras.Model):
+    raise ValueError('Only a keras `Model` instance can be used.')
+
+  if not isinstance(model, keras.Sequential) \
+      and not model._is_graph_network:  # pylint: disable=protected-access
+    raise ValueError('model should be either a keras.Sequential or a '
+                     'keras functional model.')
+
+  # Have at least 1 layer annotated with QuantizeAnnotate.
+  if not any(isinstance(layer, quant_annotate.QuantizeAnnotate)
+             for layer in model.layers):
+    raise ValueError('model does not contain any layers which have been '
+                     'annotated with `quantize_annotate`. There are no layers '
+                     'to quantize.')
+
+  def _clone_layer(layer):
+    return layer.__class__.from_config(layer.get_config())
+
+  def _quantize_activation(activation, parent_class, quantize_params):
+    try:
+      return quantize_aware_activation.QuantizeAwareActivation(
+          activation.__name__, parent_class, **quantize_params)
+    except TypeError:
+      # Non-standard activation. Could be a custom callable, or an advanced
+      # activation. Simply return the original activation for now.
+      # TODO(pulkitb): Determine how to handle custom activations and advanced
+      # activations.
+      return activation
+
+  def _get_quantize_activation_params(layer):
+    quant_params = layer.get_quantize_params()
+    # narrow_range is not relevant to quantizing activations.
+    quant_params.pop('narrow_range')
+
+    return quant_params
+
+  def _apply_quantization(quant_annotate_layer):
+    layer_to_quantize = _clone_layer(quant_annotate_layer.layer)
+    quantize_params = quant_annotate_layer.get_quantize_params()
+
+    return QuantizeEmulateWrapper(layer_to_quantize, **quantize_params)
+
+  # Apply all graph level transformations.
+  replace_map = {}
+
+  # Replace activations in layers with QuantizeAwareActivation.
+  # Dense(activation='relu') -> Dense(activation=QuantizeAwareActivation('relu'))
+  # TODO(pulkitb): Not all layers (LSTMs) have just activation. Add
+  # generic handling for all layers.
+  for layer in model.layers:
+    if isinstance(layer, quant_annotate.QuantizeAnnotate) and \
+        (layer.layer.activation is not None and
+         layer.layer.activation != keras.activations.linear):
+      quantized_layer = _apply_quantization(layer)
+
+      quantized_layer.layer.activation = _quantize_activation(
+          layer.layer.activation, layer.layer.__class__,
+          _get_quantize_activation_params(layer))
+
+      replace_map[layer] = quantized_layer
+
+  # TODO(pulkitb): Transform [Dense(), ReLU()] to be quant aware.
+
+  def _add_quant_emulate_wrapper(layer):  # pylint: disable=missing-docstring
+    # Quantized layer has been constructed during graph transformation. Return.
+    if layer in replace_map:
+      return replace_map[layer]
+
+    # No need to quantize layer. Simply clone and return.
+    if not isinstance(layer, quant_annotate.QuantizeAnnotate):
+      return _clone_layer(layer)
+
+    # Use QuantizeEmulate wrapper on annotated layer which actually adds
+    # quantization ops.
+    return _apply_quantization(layer)
+
+  return keras.models.clone_model(
+      model, input_tensors=None, clone_function=_add_quant_emulate_wrapper)
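
Taken together with `quantize_annotate` above, the intended workflow is annotate-then-apply. Below is a minimal end-to-end sketch; it is illustrative only and mirrors the Sequential-model tests added further down rather than any documented public API.

# Illustrative sketch of the annotate-then-apply flow introduced by this commit.
from tensorflow.python import keras
from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate

quant_params = {'num_bits': 8, 'symmetric': True, 'narrow_range': True}

# Annotate the layers that should be quantized.
model = keras.Sequential([
    quantize_emulate.quantize_annotate(
        keras.layers.Conv2D(32, 5, activation='relu'),
        input_shape=(28, 28, 1), **quant_params),
    quantize_emulate.quantize_annotate(keras.layers.Dense(10), **quant_params),
])

# Clone the model, wrapping each annotated layer in QuantizeEmulateWrapper and
# replacing its activation with a QuantizeAwareActivation.
quantized_model = quantize_emulate.quantize_apply(model)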

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_test.py

Lines changed: 156 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for keras pruning wrapper."""
+"""Tests for quantize API functions."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -23,11 +23,14 @@
 from tensorflow.python import keras
 from tensorflow.python.platform import test
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quant_annotate
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_emulate import QuantizeEmulate
-from tensorflow_model_optimization.python.core.quantization.keras.quantize_emulate_wrapper import QuantizeEmulateWrapper
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate_wrapper
 
 quantize_annotate = quantize_emulate.quantize_annotate
+QuantizeEmulate = quantize_emulate.QuantizeEmulate
+QuantizeEmulateWrapper = quantize_emulate_wrapper.QuantizeEmulateWrapper
+QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
 
 
 class QuantizeEmulateTest(test.TestCase):
@@ -131,5 +134,155 @@ def testQuantizeAnnotateModel_HasAnnotatedLayers(self):
     self.assertAllEqual(model.predict(inputs), annotated_model.predict(inputs))
 
 
+class QuantizeApplyTest(test.TestCase):
+
+  def setUp(self):
+    self.quant_params1 = {
+        'num_bits': 8,
+        'narrow_range': True,
+        'symmetric': True
+    }
+    self.quant_params2 = {
+        'num_bits': 4,
+        'narrow_range': False,
+        'symmetric': False
+    }
+
+  # Validation tests
+
+  def testRaisesErrorIfNotKerasModel(self):
+    with self.assertRaises(ValueError):
+      quantize_emulate.quantize_apply(keras.layers.Dense(32))
+
+  def testRaisesErrorIfKerasSubclassedModel(self):
+    class MyModel(keras.Model):
+      def call(self, inputs, training=None, mask=None):  # pylint: disable=g-wrong-blank-lines
+        return inputs
+
+    with self.assertRaises(ValueError):
+      quantize_emulate.quantize_apply(MyModel())
+
+  def testRaisesErrorNoAnnotatedLayers_Sequential(self):
+    model = keras.Sequential([
+        keras.layers.Dense(10), keras.layers.Dropout(0.4)])
+
+    with self.assertRaises(ValueError):
+      quantize_emulate.quantize_apply(model)
+
+  def testRaisesErrorNoAnnotatedLayers_Functional(self):
+    inputs = keras.Input(shape=(10,))
+    x = keras.layers.Dense(32, activation='relu')(inputs)
+    results = keras.layers.Dense(5, activation='softmax')(x)
+    model = keras.Model(inputs=inputs, outputs=results)
+
+    with self.assertRaises(ValueError):
+      quantize_emulate.quantize_apply(model)
+
+  # Quantization Apply Tests
+
+  def _get_annotated_sequential_model(self):
+    return keras.Sequential([
+        quantize_annotate(keras.layers.Conv2D(32, 5), input_shape=(28, 28, 1),
+                          **self.quant_params1),
+        quantize_annotate(keras.layers.Dense(10), **self.quant_params2)
+    ])
+
+  def _get_annotated_functional_model(self):
+    inputs = keras.Input(shape=(28, 28, 1))
+    x = quantize_annotate(
+        keras.layers.Conv2D(32, 5), **self.quant_params1)(inputs)
+    results = quantize_annotate(keras.layers.Dense(10), **self.quant_params2)(x)
+
+    return keras.Model(inputs=inputs, outputs=results)
+
+  def _assert_layer_emulated(
+      self, annotated_layer, emulated_layer, exclude_keys=None):
+    self.assertIsInstance(emulated_layer, QuantizeEmulateWrapper)
+
+    self.assertEqual(annotated_layer.get_quantize_params(),
+                     emulated_layer.get_quantize_params())
+
+    # Extract configs of the inner layers they wrap.
+    annotated_config = annotated_layer.layer.get_config()
+    emulated_config = emulated_layer.layer.get_config()
+
+    # The underlying layers aren't always exactly the same. For example,
+    # activations in the underlying layers might be replaced. Exclude keys
+    # if required.
+    if exclude_keys:
+      for key in exclude_keys:
+        annotated_config.pop(key)
+        emulated_config.pop(key)
+
+    self.assertEqual(annotated_config, emulated_config)
+
+  def _assert_model_emulated(
+      self, annotated_model, emulated_model, exclude_keys=None):
+    for annotated_layer, emulated_layer in zip(annotated_model.layers,
+                                               emulated_model.layers):
+      if isinstance(emulated_layer, keras.layers.InputLayer):
+        continue
+
+      self._assert_layer_emulated(annotated_layer, emulated_layer, exclude_keys)
+
+  def testAppliesQuantizationToAnnotatedModel_Sequential(self):
+    model = self._get_annotated_sequential_model()
+
+    quantized_model = quantize_emulate.quantize_apply(model)
+
+    self._assert_model_emulated(model, quantized_model)
+
+  def testAppliesQuantizationToAnnotatedModel_Functional(self):
+    model = self._get_annotated_functional_model()
+
+    quantized_model = quantize_emulate.quantize_apply(model)
+
+    self._assert_model_emulated(model, quantized_model)
+
+  # Transformation Tests
+
+  def testQuantizesActivationsWithinLayer_Sequential(self):
+    quant_params = {'num_bits': 8, 'symmetric': True}
+    model = keras.Sequential([
+        quantize_annotate(
+            keras.layers.Conv2D(32, 5, activation='relu'),
+            input_shape=(28, 28, 1),
+            **quant_params)
+    ])
+
+    quantized_model = quantize_emulate.quantize_apply(model)
+
+    # We expect activation to be modified.
+    self._assert_model_emulated(model, quantized_model, ['activation'])
+
+    conv_layer = quantized_model.layers[0].layer
+    self.assertIsInstance(conv_layer.activation, QuantizeAwareActivation)
+    self.assertEqual(
+        keras.activations.get('relu'), conv_layer.activation.activation)
+    self.assertEqual(keras.layers.Conv2D, conv_layer.activation.parent_layer)
+    self.assertEqual(quant_params, conv_layer.activation.get_quantize_params())
+
+  def testQuantizesActivationsWithinLayer_Functional(self):
+    quant_params = {'num_bits': 8, 'symmetric': True}
+
+    inputs = keras.Input(shape=(28, 28, 1))
+    results = quantize_annotate(
+        keras.layers.Conv2D(32, 5, activation='relu'),
+        **self.quant_params1)(inputs)
+    model = keras.Model(inputs=inputs, outputs=results)
+
+    quantized_model = quantize_emulate.quantize_apply(model)
+
+    # We expect activation to be modified.
+    self._assert_model_emulated(model, quantized_model, ['activation'])
+
+    conv_layer = quantized_model.layers[1].layer
+    self.assertIsInstance(conv_layer.activation, QuantizeAwareActivation)
+    self.assertEqual(
+        keras.activations.get('relu'), conv_layer.activation.activation)
+    self.assertEqual(keras.layers.Conv2D, conv_layer.activation.parent_layer)
+    self.assertEqual(quant_params, conv_layer.activation.get_quantize_params())
+
+
 if __name__ == '__main__':
   test.main()

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_wrapper.py

Lines changed: 6 additions & 3 deletions
@@ -185,13 +185,16 @@ def fn():
 
     return outputs
 
-  def get_config(self):
-    base_config = super(QuantizeEmulateWrapper, self).get_config()
-    config = {
+  def get_quantize_params(self):
+    return {
         'num_bits': self._num_bits,
         'symmetric': self._symmetric,
         'narrow_range': self._narrow_range
     }
+
+  def get_config(self):
+    base_config = super(QuantizeEmulateWrapper, self).get_config()
+    config = self.get_quantize_params()
     return dict(list(base_config.items()) + list(config.items()))
 
   @classmethod
