
Commit dbcba51

nutsiepully authored and tensorflower-gardener committed
Serialize/Deserialize implementation for QuantizeWrapper
PiperOrigin-RevId: 264685336
1 parent b61065a commit dbcba51

File tree

5 files changed: +127 -17 lines


tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ py_test(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
+        ":quantize_aware_activation",
         ":quantize_wrapper",
         # numpy dep1,
         # tensorflow dep1,

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py

Lines changed: 22 additions & 1 deletion
@@ -76,7 +76,12 @@ def _should_pre_quantize(self):
     # Whether we apply quantize operations around activations depends on the
     # implementation of the specific kernel. For example, ReLUs are fused in
     # whereas Softmax ops are not. Should linear have post-quantize?
-    return self.activation.__name__ in self._PRE_ACTIVATION_TYPES
+
+    # For custom quantizations unknown in keras, we default to post
+    # quantization.
+
+    return (hasattr(self.activation, '__name__') and
+            self.activation.__name__ in self._PRE_ACTIVATION_TYPES)
 
   def _add_range_weights(self, name):
     min_var = self.quantize_wrapper.add_weight(
@@ -130,3 +135,19 @@ def quantizer_fn(x=x,
                      self._max_post_activation))
 
     return x
+
+  # `QuantizeAwareActivation` wraps the activation within a layer to perform
+  # quantization. In the process, the layer's activation is replaced with
+  # `QuantizeAwareActivation`.
+  # However, when the layer is serialized and deserialized, we want the original
+  # activation to be reconstructed. This ensures that when `QuantizeWrapper`
+  # wraps the layer, it can again replace the original activation.
+
+  @classmethod
+  def from_config(cls, config):
+    return activations.deserialize(config['activation'])
+
+  def get_config(self):
+    return {
+        'activation': activations.serialize(self.activation)
+    }
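
Worth noting: the new get_config/from_config pair is deliberately asymmetric. Serialization records only the wrapped activation, and deserialization returns that original activation rather than a QuantizeAwareActivation, so QuantizeWrapper can re-wrap it when the layer is rebuilt. A minimal sketch of the round trip, condensed from the new tests below; `my_quantizer` and `my_layer` are hypothetical stand-ins for a real Quantizer and the enclosing layer:

from tensorflow.python.keras import activations
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object

from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation

QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation

# `my_quantizer` and `my_layer` are hypothetical placeholders for a real
# Quantizer instance and the layer whose activation is being wrapped.
quantize_activation = QuantizeAwareActivation(
    activations.get('tanh'), my_quantizer, 0, my_layer)

# get_config records only the wrapped activation.
config = serialize_keras_object(quantize_activation)
# -> {'class_name': 'QuantizeAwareActivation', 'config': {'activation': 'tanh'}}

# from_config returns the original activation (not a new wrapper), ready to
# be wrapped again when QuantizeWrapper rebuilds the layer.
original = deserialize_keras_object(
    config,
    custom_objects={'QuantizeAwareActivation': QuantizeAwareActivation})
# original == activations.get('tanh')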

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py

Lines changed: 64 additions & 2 deletions
@@ -23,6 +23,8 @@
 from tensorflow.python import keras
 from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.platform import test
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
@@ -64,8 +66,8 @@ def testAppliesQuantizationPostActivation(self):
     # 256 buckets.
     # Derived using `tf.fake_quant_with_min_max_vars`
     expected_activation = np.array(
-        [0.0, 0.0, 0.0, 0.04705906, 0.09411764, 3.011765, 5.9764705]
-    ).reshape(7, 1)
+        [0.0, 0.0, 0.0, 0.04705906, 0.09411764, 3.011765,
+         5.9764705]).reshape(7, 1)
 
     self.assertAllClose(expected_activation, model.predict(x))
 
@@ -88,6 +90,66 @@ def testAppliesQuantizationPreAndPostActivation(self):
 
     self.assertAllClose(expected_activation, model.predict(x))
 
+  def testSerializationReturnsWrappedActivation_BuiltInActivation(self):
+    activation = activations.get('tanh')
+    quantize_activation = QuantizeAwareActivation(
+        activation, self.quantizer, 0, self.TestLayer())
+
+    expected_config = {
+        'class_name': 'QuantizeAwareActivation',
+        'config': {'activation': 'tanh'}
+    }
+    serialized_quantize_activation = serialize_keras_object(quantize_activation)
+
+    self.assertEqual(expected_config, serialized_quantize_activation)
+
+    deserialized_activation = deserialize_keras_object(
+        serialized_quantize_activation,
+        custom_objects={'QuantizeAwareActivation': QuantizeAwareActivation})
+
+    self.assertEqual(activation, deserialized_activation)
+
+  def testSerializationReturnsWrappedActivation_CustomActivation(self):
+    class CustomActivation(object):
+
+      def __init__(self, key):
+        self.key = key
+
+      def get_config(self):
+        return {'key': self.key}
+
+      def __call__(self, *args, **kwargs):
+        return None
+
+      def __eq__(self, other):
+        return self.key == other.key
+
+    activation = CustomActivation('value')
+    quantize_activation = QuantizeAwareActivation(
+        activation, self.quantizer, 0, self.TestLayer())
+
+    expected_config = {
+        'class_name': 'QuantizeAwareActivation',
+        'config': {
+            'activation': {
+                'class_name': 'CustomActivation',
+                'config': {'key': 'value'}
+            }
+        }
+    }
+    serialized_quantize_activation = serialize_keras_object(quantize_activation)
+
+    self.assertEqual(expected_config, serialized_quantize_activation)
+
+    deserialized_activation = deserialize_keras_object(
+        serialized_quantize_activation,
+        custom_objects={
+            'QuantizeAwareActivation': QuantizeAwareActivation,
+            'CustomActivation': CustomActivation
+        })
+
+    self.assertEqual(activation, deserialized_activation)
+
 
 if __name__ == '__main__':
   test.main()

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 14 additions & 14 deletions
@@ -29,11 +29,13 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import initializers
+from tensorflow.python.keras.layers import deserialize as deserialize_layer
 from tensorflow.python.keras.layers.wrappers import Wrapper
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
-from tensorflow_model_optimization.python.core.quantization.keras import quantize_provider as quantize_provider_mod
 
 
 class QuantizeWrapper(Wrapper):
@@ -155,29 +157,27 @@ def quantizer_fn(unquantized_weight=unquantized_weight,
 
   def get_config(self):
     base_config = super(QuantizeWrapper, self).get_config()
-    config = {'quantize_provider': self.quantize_provider}
+    config = {
+        'quantize_provider': serialize_keras_object(self.quantize_provider)
+    }
     return dict(list(base_config.items()) + list(config.items()))
 
   @classmethod
   def from_config(cls, config):
     config = config.copy()
 
-    quantize_provider = config.pop('quantize_provider')
-    from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object  # pylint: disable=g-import-not-at-top
-    # TODO(pulkitb): Add all known `QuantizeProvider`s to custom_objects
-    custom_objects = {
-        'QuantizeProvider': quantize_provider_mod.QuantizeProvider
-    }
-    config['quantize_provider'] = deserialize_keras_object(
-        quantize_provider,
+    # QuantizeWrapper may be constructed with any QuantizeProvider and the
+    # wrapper itself cannot know all the possible provider classes.
+    # The deserialization code should ensure the QuantizeProvider is in keras
+    # serialization scope.
+    quantize_provider = deserialize_keras_object(
+        config.pop('quantize_provider'),
         module_objects=globals(),
-        custom_objects=custom_objects)
+        custom_objects=None)
 
-    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     layer = deserialize_layer(config.pop('layer'))
-    config['layer'] = layer
 
-    return cls(**config)
+    return cls(layer=layer, quantize_provider=quantize_provider, **config)
 
   @property
   def trainable(self):
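
With get_config now emitting a serialized QuantizeProvider and from_config resolving it via deserialize_keras_object, a round trip only works if every class named in the config (the wrapper, the wrapped activation, and the concrete provider classes) is in keras serialization scope. A minimal sketch, condensing the new test below into standalone form; constructor arguments for TFLiteQuantizeRegistry are assumed to match the test's usage:

from tensorflow.python import keras
from tensorflow.python.keras.layers import deserialize as deserialize_layer
from tensorflow.python.keras.layers import serialize as serialize_layer

from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
from tensorflow_model_optimization.python.core.quantization.keras.tflite import tflite_quantize_registry

registry = tflite_quantize_registry.TFLiteQuantizeRegistry()
layer = keras.layers.Dense(3)
wrapper = quantize_wrapper.QuantizeWrapper(
    layer=layer,
    quantize_provider=registry.get_quantize_provider(layer),
    input_shape=(2,))

serialized = serialize_layer(wrapper)

# All referenced classes must be resolvable at deserialization time,
# including the concrete QuantizeProvider/Quantizer classes supplied here
# via the registry's _types_dict().
custom_objects = {
    'QuantizeWrapper': quantize_wrapper.QuantizeWrapper,
    'QuantizeAwareActivation':
        quantize_aware_activation.QuantizeAwareActivation,
}
custom_objects.update(tflite_quantize_registry._types_dict())

with keras.utils.custom_object_scope(custom_objects):
  restored = deserialize_layer(serialized)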

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper_test.py

Lines changed: 26 additions & 0 deletions
@@ -25,11 +25,15 @@
 
 from tensorflow.python import keras
 from tensorflow.python.keras import layers
+from tensorflow.python.keras.layers import deserialize as deserialize_layer
+from tensorflow.python.keras.layers import serialize as serialize_layer
 from tensorflow.python.platform import test
 
+from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper
 from tensorflow_model_optimization.python.core.quantization.keras.tflite import tflite_quantize_registry
 
+QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
 QuantizeWrapper = quantize_wrapper.QuantizeWrapper
 TFLiteQuantizeRegistry = tflite_quantize_registry.TFLiteQuantizeRegistry
 
@@ -140,6 +144,28 @@ def _get_quantized_weights(shape, dtype):  # pylint: disable=unused-argument
         model.predict(inputs), -6.0, 6.0, num_bits=8, narrow_range=False)
     self.assertAllClose(expected_output, quantized_model.predict(inputs))
 
+  def testSerializationQuantizeWrapper(self):
+    input_shape = (2,)
+    layer = keras.layers.Dense(3)
+    wrapper = QuantizeWrapper(
+        layer=layer,
+        quantize_provider=self.quantize_registry.get_quantize_provider(layer),
+        input_shape=input_shape)
+
+    custom_objects = {
+        'QuantizeAwareActivation': QuantizeAwareActivation,
+        'QuantizeWrapper': QuantizeWrapper
+    }
+    custom_objects.update(tflite_quantize_registry._types_dict())
+
+    serialized_wrapper = serialize_layer(wrapper)
+    with keras.utils.custom_object_scope(custom_objects):
+      wrapper_from_config = deserialize_layer(serialized_wrapper)
+
+    self.assertEqual(wrapper_from_config.get_config(), wrapper.get_config())
+
+    # TODO(pulkitb): Add test to ensure weights are also preserved.
+
 
 if __name__ == '__main__':
   test.main()
