Skip to content

Commit 56e9194

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
QuantizeAnnotate add registry support.
PiperOrigin-RevId: 255683228
1 parent 031ef21 commit 56e9194

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ py_library(
107107
srcs_version = "PY2AND3",
108108
visibility = ["//visibility:public"],
109109
deps = [
110+
":quantize_emulatable_layer",
111+
":quantize_emulate_registry",
110112
# tensorflow dep1,
111113
# python/keras tensorflow dep2,
112114
],
@@ -122,6 +124,7 @@ py_test(
122124
visibility = ["//visibility:public"],
123125
deps = [
124126
":quantize_annotate",
127+
":quantize_emulatable_layer",
125128
# tensorflow dep1,
126129
# python/keras tensorflow dep2,
127130
],

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
from __future__ import print_function
2121

2222
from tensorflow.python.keras.layers.wrappers import Wrapper
23+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulatable_layer
24+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate_registry
25+
26+
QuantizeEmulatableLayer = quantize_emulatable_layer.QuantizeEmulatableLayer
27+
QuantizeEmulateRegistry = quantize_emulate_registry.QuantizeEmulateRegistry
2328

2429

2530
class QuantizeAnnotate(Wrapper):
@@ -35,6 +40,10 @@ class QuantizeAnnotate(Wrapper):
3540
modified.
3641
"""
3742

43+
_UNSUPPORTED_LAYER_ERROR_MSG = (
44+
'Layer {} not supported for quantization. Layer should either inherit '
45+
'QuantizeEmulatableLayer or be a supported keras built-in layer.')
46+
3847
def __init__(self,
3948
layer,
4049
num_bits,
@@ -52,6 +61,12 @@ def __init__(self,
5261
the minimum and maximum of each quantization range separately.
5362
**kwargs: Additional keyword arguments to be passed to the keras layer.
5463
"""
64+
65+
if not isinstance(layer, QuantizeEmulatableLayer) and \
66+
not QuantizeEmulateRegistry.supports(layer):
67+
raise ValueError(
68+
self._UNSUPPORTED_LAYER_ERROR_MSG.format(layer.__class__))
69+
5570
super(QuantizeAnnotate, self).__init__(layer, **kwargs)
5671

5772
self._num_bits = num_bits

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,58 @@
2424
from tensorflow.python.platform import test
2525

2626
from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate
27+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulatable_layer
28+
29+
QuantizeEmulatableLayer = quantize_emulatable_layer.QuantizeEmulatableLayer
2730

2831

2932
class QuantizeAnnotateTest(test.TestCase):
3033

31-
def testAppliesWrapperToAllClasses(self):
32-
layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
34+
def setUp(self):
35+
self.quant_params = {
36+
'num_bits': 8,
37+
'narrow_range': True,
38+
'symmetric': True
39+
}
40+
41+
def testRaisesErrorForUnsupportedLayer(self):
42+
class CustomLayer(keras.layers.Dense):
43+
pass
44+
45+
with self.assertRaises(ValueError):
46+
quantize_annotate.QuantizeAnnotate(CustomLayer(10), **self.quant_params)
47+
48+
def testAnnotatesCustomQuantizableLayer(self):
49+
class CustomLayerQuantizable(keras.layers.Dense, QuantizeEmulatableLayer):
50+
def get_quantizable_weights(self): # pylint: disable=g-wrong-blank-lines
51+
return [self.kernel]
52+
53+
def set_quantizable_weights(self, weights):
54+
self.kernel = weights[0]
3355

56+
annotated_layer = quantize_annotate.QuantizeAnnotate(
57+
CustomLayerQuantizable(10), **self.quant_params)
58+
59+
self.assertIsInstance(annotated_layer.layer, CustomLayerQuantizable)
60+
self.assertEqual(
61+
self.quant_params, annotated_layer.get_quantize_params())
62+
63+
def testAnnotatesKerasLayer(self):
64+
layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
3465
model = keras.Sequential([layer])
35-
wrapped_model = keras.Sequential([
66+
67+
annotated_model = keras.Sequential([
3668
quantize_annotate.QuantizeAnnotate(
37-
layer, num_bits=8, input_shape=(10,))])
69+
layer, input_shape=(10,), **self.quant_params)])
70+
71+
annotated_layer = annotated_model.layers[0]
72+
self.assertIsInstance(annotated_layer.layer, keras.layers.Dense)
73+
self.assertEqual(
74+
self.quant_params, annotated_layer.get_quantize_params())
3875

76+
# Annotated model should not affect computation. Returns same results.
3977
x_test = np.random.rand(10, 10)
40-
self.assertAllEqual(model.predict(x_test), wrapped_model.predict(x_test))
78+
self.assertAllEqual(model.predict(x_test), annotated_model.predict(x_test))
4179

4280
if __name__ == '__main__':
4381
test.main()

0 commit comments

Comments
 (0)