
Commit 9cb79a0

nutsiepully authored and tensorflower-gardener committed
QuantizeWrapper implementation based on new API
QuantizeWrapper implemented afresh to conform to the new API. It uses QuantizeProviders to apply quantization to both weights and activations, so it can now handle multiple weights/activations and cover all Keras layers. Also adds tests, similar to those for the previous wrapper, covering most Keras layers; these verify the application of post-activation quantize operations in addition to weight quantization.

PiperOrigin-RevId: 264457822
1 parent 624155b commit 9cb79a0
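
For context, here is a minimal usage sketch of the new API. Only the `QuantizeWrapper` constructor signature is taken from this commit; the provider instance and the wrapped Dense layer are hypothetical stand-ins, since the commit does not show how providers are obtained.

# Hedged usage sketch: wrapping a single Keras layer for quantization-aware
# training. `my_quantize_provider` stands in for a real QuantizeProvider; the
# BUILD file below references a tflite_quantize_registry target, but its
# lookup API is not part of this diff.
from tensorflow.python import keras

from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper

layer = keras.layers.Dense(10, activation='relu')
wrapped = quantize_wrapper.QuantizeWrapper(layer, my_quantize_provider)

# The wrapper is itself a Keras layer and can be used inside a model.
model = keras.Sequential(
    [keras.layers.InputLayer(input_shape=(20,)), wrapped])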

3 files changed: +412 −30 lines


tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 63 additions & 30 deletions
@@ -98,6 +98,69 @@ py_library(
     ],
 )
 
+py_library(
+    name = "quantize_aware_activation",
+    srcs = [
+        "quantize_aware_activation.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":quant_ops",
+        # tensorflow dep1,
+        # python/keras tensorflow dep2,
+    ],
+)
+
+py_test(
+    name = "quantize_aware_activation_test",
+    srcs = [
+        "quantize_aware_activation_test.py",
+    ],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":quantize_aware_activation",
+        ":quantizers",
+        # tensorflow dep1,
+        # python/keras tensorflow dep2,
+    ],
+)
+
+py_library(
+    name = "quantize_wrapper",
+    srcs = [
+        "quantize_wrapper.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":quantize_aware_activation",
+        ":quantize_provider",
+        ":quantizers",
+        # tensorflow dep1,
+        # python/keras tensorflow dep2,
+    ],
+)
+
+py_test(
+    name = "quantize_wrapper_test",
+    srcs = [
+        "quantize_wrapper_test.py",
+    ],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        ":quantize_wrapper",
+        # numpy dep1,
+        # tensorflow dep1,
+        # python/keras tensorflow dep2,
+        "//tensorflow_model_optimization/python/core/quantization/keras/tflite:tflite_quantize_registry",
+    ],
+)
+
 py_library(
     name = "quantize_emulatable_layer",
     srcs = [
@@ -205,36 +268,6 @@ py_test(
     ],
 )
 
-py_library(
-    name = "quantize_aware_activation",
-    srcs = [
-        "quantize_aware_activation.py",
-    ],
-    srcs_version = "PY2AND3",
-    visibility = ["//visibility:public"],
-    deps = [
-        ":quant_ops",
-        # tensorflow dep1,
-        # python/keras tensorflow dep2,
-    ],
-)
-
-py_test(
-    name = "quantize_aware_activation_test",
-    srcs = [
-        "quantize_aware_activation_test.py",
-    ],
-    python_version = "PY3",
-    srcs_version = "PY2AND3",
-    visibility = ["//visibility:public"],
-    deps = [
-        ":quantize_aware_activation",
-        ":quantizers",
-        # tensorflow dep1,
-        # python/keras tensorflow dep2,
-    ],
-)
-
 py_library(
     name = "utils",
     srcs = [
tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper which applies quantization operations over underlying layer.

`QuantizeWrapper` is responsible for modifying the construction of the
underlying layer to ensure proper quantization operations are placed in the
graph.

These operations ensure proper introduction of inference time losses during
training.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import initializers
from tensorflow.python.keras.layers.wrappers import Wrapper
from tensorflow.python.keras.utils import tf_utils

from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
from tensorflow_model_optimization.python.core.quantization.keras import quantize_provider as quantize_provider_mod


class QuantizeWrapper(Wrapper):
  """Quantizes the weights and activations of the keras layer it wraps."""

  def __init__(self, layer, quantize_provider, **kwargs):
    """Create a quantize emulate wrapper for a keras layer.

    Args:
      layer: The keras layer to be quantized.
      quantize_provider: `QuantizeProvider` to quantize layer.
      **kwargs: Additional keyword arguments to be passed to the keras layer.
    """

    if quantize_provider is None:
      raise ValueError('quantize_provider cannot be None. It is needed to '
                       'quantize a layer.')

    super(QuantizeWrapper, self).__init__(layer, **kwargs)
    self.quantize_provider = quantize_provider

    # Ensures cloning of already built layer works.
    if (not hasattr(self, '_batch_input_shape') and
        hasattr(layer, '_batch_input_shape')):
      self._batch_input_shape = self.layer._batch_input_shape  # pylint: disable=protected-access
    self._track_trackable(layer, name='layer')

  @staticmethod
  def _weight_name(name):
    """Extracts the weight name from the full TensorFlow variable name.

    For example, returns 'kernel' for 'dense_2/kernel:0'.

    Args:
      name: TensorFlow variable name.

    Returns:
      Extracted weight name.
    """
    return name.split(':')[0].split('/')[-1]

  def _add_range_weights(self, name):
    min_weight = self.add_weight(
        name + '_min', initializer=initializers.Constant(-6.0), trainable=False)
    max_weight = self.add_weight(
        name + '_max', initializer=initializers.Constant(6.0), trainable=False)

    return min_weight, max_weight

  def build(self, input_shape):
    super(QuantizeWrapper, self).build(input_shape)

    self.optimizer_step = self.add_weight(
        'optimizer_step',
        initializer=initializers.Constant(-1),
        dtype=dtypes.int32,
        trainable=False)

    self._weight_vars = []
    for weight, quantizer in \
        self.quantize_provider.get_weights_and_quantizers(self.layer):
      min_var, max_var = self._add_range_weights(self._weight_name(weight.name))

      self._weight_vars.append((weight, quantizer, min_var, max_var))
      # Needed to ensure unquantized weights get trained as part of the wrapper.
      self._trainable_weights.append(weight)

    self._quantize_activations = []
    for activation, quantizer in \
        self.quantize_provider.get_activations_and_quantizers(self.layer):
      quantize_activation = quantize_aware_activation.QuantizeAwareActivation(
          activation, quantizer, self.optimizer_step, self)

      self._quantize_activations.append(quantize_activation)

  def compute_output_shape(self, input_shape):
    return self.layer.compute_output_shape(self.layer.input_shape)

  def _dict_vars(self, min_var, max_var):
    return {'min_var': min_var, 'max_var': max_var}

  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    # Quantize all weights, and replace them in the underlying layer.

    quantized_weights = []
    for unquantized_weight, quantizer, min_var, max_var in self._weight_vars:

      def make_quantizer_fn(training):
        """Use currying to return True/False specialized fns to the cond."""

        def quantizer_fn(unquantized_weight=unquantized_weight,
                         quantizer=quantizer,
                         min_var=min_var,
                         max_var=max_var):
          return quantizer(unquantized_weight, self.optimizer_step, training,
                           **self._dict_vars(min_var, max_var))

        return quantizer_fn

      quantized_weight = tf_utils.smart_cond(
          training, make_quantizer_fn(True), make_quantizer_fn(False))
      quantized_weights.append(quantized_weight)

    self.quantize_provider.set_quantize_weights(self.layer, quantized_weights)

    # Replace all activations with `QuantizeAwareActivation`s which can
    # quantize activation tensors during graph construction.

    for quantize_activation in self._quantize_activations:
      quantize_activation.training = training

    self.quantize_provider.set_quantize_activations(
        self.layer, self._quantize_activations)

    return self.layer.call(inputs)

  def get_config(self):
    base_config = super(QuantizeWrapper, self).get_config()
    config = {'quantize_provider': self.quantize_provider}
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config):
    config = config.copy()

    quantize_provider = config.pop('quantize_provider')
    from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object  # pylint: disable=g-import-not-at-top
    # TODO(pulkitb): Add all known `QuantizeProvider`s to custom_objects
    custom_objects = {
        'QuantizeProvider': quantize_provider_mod.QuantizeProvider
    }
    config['quantize_provider'] = deserialize_keras_object(
        quantize_provider,
        module_objects=globals(),
        custom_objects=custom_objects)

    from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    layer = deserialize_layer(config.pop('layer'))
    config['layer'] = layer

    return cls(**config)

  @property
  def trainable(self):
    return self.layer.trainable

  @trainable.setter
  def trainable(self, value):
    self.layer.trainable = value

  @property
  def trainable_weights(self):
    return self.layer.trainable_weights + self._trainable_weights

  @property
  def non_trainable_weights(self):
    return self.layer.non_trainable_weights + self._non_trainable_weights

  @property
  def updates(self):
    return self.layer.updates + self._updates

  @property
  def losses(self):
    return self.layer.losses + self._losses
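
A brief sketch of the serialization round trip defined by `get_config`/`from_config` above. This assumes `wrapped` is an existing `QuantizeWrapper` whose provider can actually be rebuilt through the `custom_objects` mapping; the TODO in the code notes that the set of known providers is still incomplete.

# get_config() merges the base Wrapper config (which serializes the inner
# layer) with the 'quantize_provider' entry.
config = wrapped.get_config()

# from_config() reverses this: it rebuilds the provider via
# deserialize_keras_object and the inner layer via deserialize_layer.
clone = QuantizeWrapper.from_config(config)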
