
Commit fb389ff

nutsiepully authored and tensorflower-gardener committed
Move quantize_annotate to new API structure.
This CL begins the migration to the new API structure. Migrating everything at once would break many tests or require one very large CL touching the entire codebase, so this CL changes only the quantize_annotate layers.

PiperOrigin-RevId: 263236324
1 parent e7f003a commit fb389ff
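
For readers following the migration, here is a minimal sketch of the API change this CL makes. MyQuantizeProvider is a hypothetical subclass; the sketch assumes only the QuantizeProvider interface and the QuantizeAnnotate signature visible in the diffs below.

# Minimal sketch of the API change in this CL. MyQuantizeProvider is a
# hypothetical subclass used for illustration only.
from tensorflow import keras

from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate
from tensorflow_model_optimization.python.core.quantization.keras import quantize_provider as quantize_provider_mod


class MyQuantizeProvider(quantize_provider_mod.QuantizeProvider):
  """Hypothetical provider; stubbed out like the test providers below."""

  def get_weights_and_quantizers(self, layer):
    pass

  def get_activations_and_quantizers(self, layer):
    pass


# Old API: quantization parameters were passed to the wrapper directly.
#   quantize_annotate.QuantizeAnnotate(
#       keras.layers.Dense(10), num_bits=8, narrow_range=True, symmetric=True)

# New API: a QuantizeProvider object carries the quantization scheme.
annotated = quantize_annotate.QuantizeAnnotate(
    keras.layers.Dense(10), quantize_provider=MyQuantizeProvider())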

File tree

5 files changed, +70 -120 lines changed

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 2 additions & 3 deletions

@@ -183,8 +183,7 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":quantize_emulatable_layer",
-        ":quantize_emulate_registry",
+        ":quantize_provider",
         # tensorflow dep1,
         # python/keras tensorflow dep2,
     ],
@@ -200,7 +199,7 @@ py_test(
     visibility = ["//visibility:public"],
     deps = [
         ":quantize_annotate",
-        ":quantize_emulatable_layer",
+        ":quantize_provider",
         # tensorflow dep1,
         # python/keras tensorflow dep2,
     ],

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

Lines changed: 22 additions & 26 deletions

@@ -20,11 +20,7 @@
 from __future__ import print_function
 
 from tensorflow.python.keras.layers.wrappers import Wrapper
-from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulatable_layer
-from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate_registry
-
-QuantizeEmulatableLayer = quantize_emulatable_layer.QuantizeEmulatableLayer
-QuantizeEmulateRegistry = quantize_emulate_registry.QuantizeEmulateRegistry
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_provider as quantize_provider_mod
 
 
 class QuantizeAnnotate(Wrapper):
@@ -46,52 +42,52 @@ class QuantizeAnnotate(Wrapper):
 
   def __init__(self,
                layer,
-               num_bits,
-               narrow_range=True,
-               symmetric=True,
+               quantize_provider=None,
                **kwargs):
     """Create a quantize annotate wrapper over a keras layer.
 
     Args:
       layer: The keras layer to be quantized.
-      num_bits: Number of bits for quantization
-      narrow_range: Whether to use the narrow quantization range [1; 2^num_bits
-        - 1] or wide range [0; 2^num_bits - 1].
-      symmetric: If true, use symmetric quantization limits instead of training
-        the minimum and maximum of each quantization range separately.
+      quantize_provider: `QuantizeProvider` to quantize layer.
      **kwargs: Additional keyword arguments to be passed to the keras layer.
    """
-
-    if not isinstance(layer, QuantizeEmulatableLayer) and \
-        not QuantizeEmulateRegistry.supports(layer):
-      raise ValueError(
-          self._UNSUPPORTED_LAYER_ERROR_MSG.format(layer.__class__))
-
     super(QuantizeAnnotate, self).__init__(layer, **kwargs)
 
-    self._num_bits = num_bits
-    self._narrow_range = narrow_range
-    self._symmetric = symmetric
+    self.quantize_provider = quantize_provider
 
   def call(self, inputs, training=None):
     return self.layer.call(inputs)
 
   def get_quantize_params(self):
+    # TODO(pulkitb): Keep around function so rest of code works. Remove later.
     return {
-        'num_bits': self._num_bits,
-        'symmetric': self._symmetric,
-        'narrow_range': self._narrow_range
+        'num_bits': 8,
+        'symmetric': True,
+        'narrow_range': True
     }
 
   def get_config(self):
     base_config = super(QuantizeAnnotate, self).get_config()
-    config = self.get_quantize_params()
+    config = {
+        'quantize_provider': self.quantize_provider
+    }
     return dict(list(base_config.items()) + list(config.items()))
 
   @classmethod
   def from_config(cls, config):
     config = config.copy()
 
+    quantize_provider = config.pop('quantize_provider')
+    from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object  # pylint: disable=g-import-not-at-top
+    # TODO(pulkitb): Add all known `QuantizeProvider`s to custom_objects
+    custom_objects = {
+        'QuantizeProvider': quantize_provider_mod.QuantizeProvider
+    }
+    config['quantize_provider'] = deserialize_keras_object(
+        quantize_provider,
+        module_objects=globals(),
+        custom_objects=custom_objects)
+
     from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     layer = deserialize_layer(config.pop('layer'))
     config['layer'] = layer
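
As a usage note, here is a sketch of the serialization path implied by the get_config/from_config changes above, reusing the hypothetical MyQuantizeProvider from the earlier sketch:

# Sketch only; MyQuantizeProvider is the hypothetical subclass from the
# first example.
wrapper = quantize_annotate.QuantizeAnnotate(
    keras.layers.Dense(10), quantize_provider=MyQuantizeProvider())

config = wrapper.get_config()
# config carries the wrapped layer (from Wrapper.get_config) plus the new
# 'quantize_provider' entry. from_config resolves that entry through
# deserialize_keras_object using the custom_objects mapping in the diff.
print(sorted(config.keys()))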

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py

Lines changed: 10 additions & 32 deletions

@@ -24,58 +24,36 @@
 from tensorflow.python.platform import test
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate
-from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulatable_layer
-
-QuantizeEmulatableLayer = quantize_emulatable_layer.QuantizeEmulatableLayer
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_provider as quantize_provider_mod
 
 
 class QuantizeAnnotateTest(test.TestCase):
 
-  def setUp(self):
-    self.quant_params = {
-        'num_bits': 8,
-        'narrow_range': True,
-        'symmetric': True
-    }
+  class TestQuantizeProvider(quantize_provider_mod.QuantizeProvider):
 
-  def testRaisesErrorForUnsupportedLayer(self):
-    class CustomLayer(keras.layers.Dense):
+    def get_weights_and_quantizers(self, layer):
       pass
 
-    with self.assertRaises(ValueError):
-      quantize_annotate.QuantizeAnnotate(CustomLayer(10), **self.quant_params)
-
-  def testAnnotatesCustomQuantizableLayer(self):
-    class CustomLayerQuantizable(keras.layers.Dense, QuantizeEmulatableLayer):
-      def get_quantizable_weights(self):  # pylint: disable=g-wrong-blank-lines
-        return [self.kernel]
-
-      def set_quantizable_weights(self, weights):
-        self.kernel = weights[0]
-
-    annotated_layer = quantize_annotate.QuantizeAnnotate(
-        CustomLayerQuantizable(10), **self.quant_params)
-
-    self.assertIsInstance(annotated_layer.layer, CustomLayerQuantizable)
-    self.assertEqual(
-        self.quant_params, annotated_layer.get_quantize_params())
+    def get_activations_and_quantizers(self, layer):
+      pass
 
   def testAnnotatesKerasLayer(self):
     layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
     model = keras.Sequential([layer])
 
+    quantize_provider = self.TestQuantizeProvider()
     annotated_model = keras.Sequential([
         quantize_annotate.QuantizeAnnotate(
-            layer, input_shape=(10,), **self.quant_params)])
+            layer, quantize_provider=quantize_provider, input_shape=(10,))])
 
     annotated_layer = annotated_model.layers[0]
-    self.assertIsInstance(annotated_layer.layer, keras.layers.Dense)
-    self.assertEqual(
-        self.quant_params, annotated_layer.get_quantize_params())
+    self.assertEqual(layer, annotated_layer.layer)
+    self.assertEqual(quantize_provider, annotated_layer.quantize_provider)
 
     # Annotated model should not affect computation. Returns same results.
     x_test = np.random.rand(10, 10)
     self.assertAllEqual(model.predict(x_test), annotated_model.predict(x_test))
 
+
 if __name__ == '__main__':
   test.main()

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate.py

Lines changed: 5 additions & 20 deletions

@@ -88,12 +88,7 @@ def _QuantizeList(layers, **params):
 
 
 # TODO(pulkitb): Enable lint naming is fixed and made consistent.
-def quantize_annotate(
-    to_quantize,
-    num_bits,
-    narrow_range=True,
-    symmetric=True,
-    **kwargs):  # pylint: disable=invalid-name
+def quantize_annotate(to_quantize, **kwargs):  # pylint: disable=invalid-name
   """Specify a layer or model to be quantized.
 
   This function does not apply an quantization emulation operations. It merely
@@ -102,11 +97,6 @@ def quantize_annotate(
 
   Args:
     to_quantize: Keras layer or model to be quantized.
-    num_bits: Number of bits for quantization
-    narrow_range: Whether to use the narrow quantization range [1; 2^num_bits
-      - 1] or wide range [0; 2^num_bits - 1].
-    symmetric: If true, use symmetric quantization limits instead of training
-      the minimum and maximum of each quantization range separately.
     **kwargs: Additional keyword arguments to be passed to the keras layer.
 
   Returns:
@@ -119,20 +109,15 @@ def _add_quant_wrapper(layer):
     if isinstance(layer, quant_annotate.QuantizeAnnotate):
       return layer
 
-    return quant_annotate.QuantizeAnnotate(layer, **quant_params)
-
-  quant_params = {
-      'num_bits': num_bits,
-      'narrow_range': narrow_range,
-      'symmetric': symmetric
-  }
+    return quant_annotate.QuantizeAnnotate(layer)
 
   if isinstance(to_quantize, keras.Model):
     return keras.models.clone_model(
         to_quantize, input_tensors=None, clone_function=_add_quant_wrapper)
   elif isinstance(to_quantize, keras.layers.Layer):
-    quant_params.update(**kwargs)
-    return quant_annotate.QuantizeAnnotate(to_quantize, **quant_params)
+    # TODO(pulkitb): Since annotation for model and layer have different
+    # parameters, we should likely remove support for layers here.
+    return quant_annotate.QuantizeAnnotate(to_quantize, **kwargs)
 
 
 def quantize_apply(model):
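
A short sketch of the simplified entry point after this change, mirroring the updated tests below; nothing is assumed beyond the new quantize_annotate signature in the diff above:

# Sketch of the simplified entry point: annotating a whole model no longer
# takes per-call quantization parameters.
from tensorflow import keras

from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate

model = keras.Sequential([
    keras.layers.Dense(10, input_shape=(5,)),
    keras.layers.Dropout(0.4),
])

# Wraps every layer in QuantizeAnnotate; layers that are already annotated
# are left untouched by _add_quant_wrapper.
annotated_model = quantize_emulate.quantize_annotate(model)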

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_test.py

Lines changed: 31 additions & 39 deletions

@@ -27,6 +27,7 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate_wrapper
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_provider as quantize_provider_mod
 
 quantize_annotate = quantize_emulate.quantize_annotate
 QuantizeEmulate = quantize_emulate.QuantizeEmulate
@@ -67,35 +68,20 @@ def testQuantizeEmulateList(self):
 
 class QuantizeAnnotateTest(test.TestCase):
 
-  def setUp(self):
-    self.quant_params = {
-        'num_bits': 8,
-        'narrow_range': True,
-        'symmetric': True
-    }
-
-  def _assertQuantParams(self, layer, quant_params):
-    layer_params = {
-        'num_bits': layer._num_bits,
-        'narrow_range': layer._narrow_range,
-        'symmetric': layer._symmetric
-    }
-    self.assertEqual(quant_params, layer_params)
-
-  def _assertWrappedLayer(self, layer, quant_params):
+  def _assertWrappedLayer(self, layer, quantize_provider=None):
     self.assertIsInstance(layer, quant_annotate.QuantizeAnnotate)
-    self._assertQuantParams(layer, quant_params)
+    self.assertEqual(quantize_provider, layer.quantize_provider)
 
-  def _assertWrappedSequential(self, model, quant_params):
+  def _assertWrappedModel(self, model):
     for layer in model.layers:
-      self._assertWrappedLayer(layer, quant_params)
+      self._assertWrappedLayer(layer)
 
   def testQuantizeAnnotateLayer(self):
     layer = keras.layers.Dense(10, input_shape=(5,))
     wrapped_layer = quantize_annotate(
-        layer, input_shape=(5,), **self.quant_params)
+        layer, input_shape=(5,))
 
-    self._assertWrappedLayer(wrapped_layer, self.quant_params)
+    self._assertWrappedLayer(wrapped_layer)
 
     inputs = np.random.rand(1, 5)
     model = keras.Sequential([layer])
@@ -110,24 +96,33 @@ def testQuantizeAnnotateModel(self):
         keras.layers.Dense(10, input_shape=(5,)),
         keras.layers.Dropout(0.4)
     ])
-    annotated_model = quantize_annotate(model, **self.quant_params)
+    annotated_model = quantize_annotate(model)
 
-    self._assertWrappedSequential(annotated_model, self.quant_params)
+    self._assertWrappedModel(annotated_model)
 
     inputs = np.random.rand(1, 5)
     self.assertAllEqual(model.predict(inputs), annotated_model.predict(inputs))
 
   def testQuantizeAnnotateModel_HasAnnotatedLayers(self):
-    layer_params = {'num_bits': 4, 'narrow_range': False, 'symmetric': False}
+    class TestQuantizeProvider(quantize_provider_mod.QuantizeProvider):
+
+      def get_weights_and_quantizers(self, layer):
+        pass
+
+      def get_activations_and_quantizers(self, layer):
+        pass
+
+    quantize_provider = TestQuantizeProvider()
+
     model = keras.Sequential([
         keras.layers.Dense(10, input_shape=(5,)),
-        quantize_annotate(keras.layers.Dense(5), **layer_params)
+        quant_annotate.QuantizeAnnotate(
+            keras.layers.Dense(5), quantize_provider=quantize_provider)
     ])
+    annotated_model = quantize_annotate(model)
 
-    annotated_model = quantize_annotate(model, **self.quant_params)
-
-    self._assertWrappedLayer(annotated_model.layers[0], self.quant_params)
-    self._assertWrappedLayer(annotated_model.layers[1], layer_params)
+    self._assertWrappedLayer(annotated_model.layers[0])
+    self._assertWrappedLayer(annotated_model.layers[1], quantize_provider)
     # Ensure an already annotated layer is not wrapped again.
     self.assertIsInstance(annotated_model.layers[1].layer, keras.layers.Dense)
 
@@ -181,7 +176,7 @@ def testRaisesErrorNoAnnotatedLayers_Functional(self):
 
   def testRaisesErrorModelNotBuilt(self):
     model = keras.Sequential([
-        quantize_annotate(keras.layers.Dense(10), **self.quant_params1)])
+        quantize_annotate(keras.layers.Dense(10))])
 
     self.assertFalse(model.built)
     with self.assertRaises(ValueError):
@@ -191,16 +186,15 @@ def testRaisesErrorModelNotBuilt(self):
 
   def _get_annotated_sequential_model(self):
     return keras.Sequential([
-        quantize_annotate(keras.layers.Conv2D(32, 5), input_shape=(28, 28, 1),
-                          **self.quant_params1),
-        quantize_annotate(keras.layers.Dense(10), **self.quant_params2)
+        quantize_annotate(keras.layers.Conv2D(32, 5), input_shape=(28, 28, 1)),
+        quantize_annotate(keras.layers.Dense(10))
     ])
 
   def _get_annotated_functional_model(self):
     inputs = keras.Input(shape=(28, 28, 1))
     x = quantize_annotate(
-        keras.layers.Conv2D(32, 5), **self.quant_params1)(inputs)
-    results = quantize_annotate(keras.layers.Dense(10), **self.quant_params2)(x)
+        keras.layers.Conv2D(32, 5))(inputs)
+    results = quantize_annotate(keras.layers.Dense(10))(x)
 
     return keras.Model(inputs=inputs, outputs=results)
@@ -283,8 +277,7 @@ def testQuantizesActivationsWithinLayer_Sequential(self):
     model = keras.Sequential([
         quantize_annotate(
             keras.layers.Conv2D(32, 5, activation='relu'),
-            input_shape=(28, 28, 1),
-            **quant_params)
+            input_shape=(28, 28, 1))
     ])
 
     quantized_model = quantize_emulate.quantize_apply(model)
@@ -304,8 +297,7 @@ def testQuantizesActivationsWithinLayer_Functional(self):
 
     inputs = keras.Input(shape=(28, 28, 1))
     results = quantize_annotate(
-        keras.layers.Conv2D(32, 5, activation='relu'),
-        **self.quant_params1)(inputs)
+        keras.layers.Conv2D(32, 5, activation='relu'))(inputs)
     model = keras.Model(inputs=inputs, outputs=results)
 
     quantized_model = quantize_emulate.quantize_apply(model)