
Commit a0b2291

nutsiepully authored and tensorflower-gardener committed
Modify QuantizeAwareActivation to work with new API.
QuantizeAwareActivation now functions as a callable which can be created by the parent wrapper. This callable can be substituted in place of actual activations and inserts quantize operations appropriately.

PiperOrigin-RevId: 264439215
1 parent 110b6c9 commit a0b2291
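
For orientation, here is a minimal sketch of the new calling convention, modeled on the test added in this commit. The `SimpleWrapper` layer is only an illustrative stand-in for the parent wrapper that will own the activation; it is not part of this change.

# Sketch only: `SimpleWrapper` is a hypothetical stand-in for the parent
# wrapper; the QuantizeAwareActivation call pattern mirrors the test below.
from tensorflow.python import keras
from tensorflow.python.keras import activations

from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
from tensorflow_model_optimization.python.core.quantization.keras import quantizers


class SimpleWrapper(keras.layers.Layer):
  """Owns a QuantizeAwareActivation and calls it like a regular activation."""

  def call(self, inputs):
    # The callable is substituted in place of the actual activation and
    # inserts quantize operations around it as needed.
    return self.activation(inputs)

  def compute_output_shape(self, input_shape):
    return input_shape


layer = SimpleWrapper()
layer.activation = quantize_aware_activation.QuantizeAwareActivation(
    activations.get('relu'),                          # activation to wrap
    quantizers.MovingAverageQuantizer(
        num_bits=8, per_axis=False, symmetric=True),  # quantizer for the ranges
    0,                                                # optimizer step
    layer)                                            # wrapper owning the range weights

model = keras.Sequential([layer])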

File tree

3 files changed: +81 −98 lines


tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 1 addition & 0 deletions
@@ -229,6 +229,7 @@ py_test(
     visibility = ["//visibility:public"],
     deps = [
         ":quantize_aware_activation",
+        ":quantizers",
         # tensorflow dep1,
         # python/keras tensorflow dep2,
     ],

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py

Lines changed: 53 additions & 94 deletions
@@ -19,14 +19,10 @@
 from __future__ import print_function
 
 from tensorflow.python.keras import activations
-from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import initializers
-from tensorflow.python.keras.layers import Layer
 
-from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
 
-
-class QuantizeAwareActivation(Layer):
+class QuantizeAwareActivation(object):
   """Activation layer for quantization aware training.
 
   The goal of this layer is to apply quantize operations during training such
@@ -51,107 +47,70 @@ class QuantizeAwareActivation(Layer):
 
   _PRE_ACTIVATION_TYPES = {'softmax'}
 
-  def __init__(
-      self,
-      activation,
-      parent_layer,
-      num_bits,
-      symmetric=True,
-      **kwargs):
+  def __init__(self, activation, quantizer, step, quantize_wrapper):
     """Construct a QuantizeAwareActivation layer.
 
     Args:
-      activation: Activation function to use.
-        If you don't specify anything, no activation is applied
+      activation: Activation function to use. If you don't specify anything, no
+        activation is applied
        (ie. "linear" activation: `a(x) = x`).
-      parent_layer: The layer this activation is being applied to. Such
-        as Conv2D, Dense etc.
-      num_bits: Number of bits for quantization
-      symmetric: If true, use symmetric quantization limits instead of training
-        the minimum and maximum of each quantization range separately.
-      **kwargs: Additional keyword arguments to be passed to the keras layer.
+      quantizer: `Quantizer` to be used to quantize the activation.
+      step: Variable which tracks optimizer step.
+      quantize_wrapper: `QuantizeWrapper` which owns this activation.
     """
-    super(QuantizeAwareActivation, self).__init__(**kwargs)
-
     self.activation = activations.get(activation)
-    self.parent_layer = parent_layer
+    self.quantizer = quantizer
+    self.step = step
+    self.quantize_wrapper = quantize_wrapper
+
+    self._training = False
 
-    self.num_bits = num_bits
-    self.symmetric = symmetric
+    if self._should_pre_quantize():
+      self._min_pre_activation, self._max_pre_activation = \
+          self._add_range_weights('pre_activation')
 
-    # TODO(pulkitb): Generate a meaningful name for this layer, which
-    # ideally also includes the parent layer.
+    self._min_post_activation, self._max_post_activation = \
+        self._add_range_weights('post_activation')
 
-  def _requires_pre_quant(self):
-    # TODO(pulkitb): Make this more sophisticated. This should match the
-    # implementation of kernels on-device.
+  def _should_pre_quantize(self):
+    # TODO(pulkitb): Add logic to deduce whether we should pre-quantize.
+    # Whether we apply quantize operations around activations depends on the
+    # implementation of the specific kernel. For example, ReLUs are fused in
+    # whereas Softmax ops are not. Should linear have post-quantize?
     return self.activation.__name__ in self._PRE_ACTIVATION_TYPES
 
-  def build(self, input_shape):
-    if self._requires_pre_quant():
-      self._min_pre_activation = self.add_variable(
-          'min_pre_activation',
-          initializer=initializers.Constant(-6.0),
-          trainable=False)
-      self._max_pre_activation = self.add_variable(
-          'max_pre_activation',
-          initializer=initializers.Constant(6.0),
-          trainable=False)
-
-    self._min_post_activation = self.add_variable(
-        'min_post_activation',
-        initializer=initializers.Constant(-6.0),
-        trainable=False)
-    self._max_post_activation = self.add_variable(
-        'max_post_activation',
-        initializer=initializers.Constant(6.0),
-        trainable=False)
-
-  def call(self, inputs, training=None):
-    # TODO(pulkitb): Construct graph for both training/eval modes.
-    if training is None:
-      training = K.learning_phase()
+  def _add_range_weights(self, name):
+    min_var = self.quantize_wrapper.add_weight(
+        name + '_min', initializer=initializers.Constant(-6.0), trainable=False)
+    max_var = self.quantize_wrapper.add_weight(
+        name + '_max', initializer=initializers.Constant(6.0), trainable=False)
+
+    return min_var, max_var
+
+  @property
+  def training(self):
+    return self._training
 
+  @training.setter
+  def training(self, value):
+    self._training = value
+
+  def __call__(self, inputs, *args, **kwargs):
+    # TODO(pulkitb): Add cond here to handle training properly.
     x = inputs
-    if self._requires_pre_quant():
-      x = quant_ops.MovingAvgQuantize(
-          inputs,
-          self._min_pre_activation,
-          self._max_pre_activation,
-          ema_decay=0.999,
-          is_training=training,
-          num_bits=self.num_bits,
-          symmetric=self.symmetric,
-          name_prefix=self.name)
-
-    x = self.activation(x)
-    x = quant_ops.MovingAvgQuantize(
-        x,
-        self._min_post_activation,
-        self._max_post_activation,
-        ema_decay=0.999,
-        is_training=training,
-        num_bits=self.num_bits,
-        symmetric=self.symmetric,
-        name_prefix=self.name)
+    if self._should_pre_quantize():
+      x = self.quantizer(
+          x, self.step, self._training, **{
+              'min_var': self._min_pre_activation,
+              'max_var': self._max_pre_activation
+          })
 
-    return x
+    x = self.activation(x, *args, **kwargs)
 
-  def get_quantize_params(self):
-    return {
-        'num_bits': self.num_bits,
-        'symmetric': self.symmetric,
-    }
-
-  def compute_output_shape(self, input_shape):
-    return input_shape
-
-  def get_config(self):
-    base_config = super(QuantizeAwareActivation, self).get_config()
-    config = {
-        'activation': activations.serialize(self.activation),
-        'parent_layer': self.parent_layer,
-        'num_bits': self.num_bits,
-        'symmetric': self.symmetric,
-    }
-    return dict(list(base_config.items()) + list(config.items()))
+    x = self.quantizer(
+        x, self.step, self._training, **{
+            'min_var': self._min_post_activation,
+            'max_var': self._max_post_activation
+        })
+
+    return x
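
To make the new pre-quantize decision concrete, the standalone snippet below (not part of the change) replays the `_should_pre_quantize` logic from the diff above: only activations in `_PRE_ACTIVATION_TYPES`, currently just softmax, get a quantize op inserted before the activation as well as after it, while fused activations such as ReLU only get the post-activation quantize.

# Standalone illustration of the pre-quantize decision; mirrors
# QuantizeAwareActivation._should_pre_quantize in the diff above.
_PRE_ACTIVATION_TYPES = {'softmax'}


def should_pre_quantize(activation_name):
  # Fused kernels such as ReLU only need their output quantized; ops like
  # softmax also get their input quantized.
  return activation_name in _PRE_ACTIVATION_TYPES


assert not should_pre_quantize('relu')     # post-activation quantize only
assert should_pre_quantize('softmax')      # pre- and post-activation quantize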

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py

Lines changed: 27 additions & 4 deletions
@@ -21,18 +21,37 @@
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.keras import activations
 from tensorflow.python.platform import test
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
+from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 
 QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
+MovingAverageQuantizer = quantizers.MovingAverageQuantizer
 
 
 class QuantizeAwareQuantizationTest(test.TestCase):
 
+  def setUp(self):
+    super(QuantizeAwareQuantizationTest, self).setUp()
+    self.quantizer = MovingAverageQuantizer(
+        num_bits=8, per_axis=False, symmetric=True)
+
+  class TestLayer(keras.layers.Layer):
+
+    def call(self, inputs):
+      return self.activation(inputs)
+
+    def compute_output_shape(self, input_shape):
+      return input_shape
+
   def testAppliesQuantizationPostActivation(self):
-    model = keras.Sequential([
-        QuantizeAwareActivation('relu', 'dense', num_bits=8)])
+    layer = self.TestLayer()
+    layer.activation = QuantizeAwareActivation(
+        activations.get('relu'), self.quantizer, 0, layer)
+
+    model = keras.Sequential([layer])
 
     x = np.array([-6.0, -3.0, 0.0, 0.05, 0.1, 3.0, 6.0])
     # All negative values are removed due to ReLU. The other expected values
@@ -46,8 +65,11 @@ def testAppliesQuantizationPostActivation(self):
     self.assertAllClose(expected_activation, model.predict(x))
 
   def testAppliesQuantizationPreAndPostActivation(self):
-    model = keras.Sequential([
-        QuantizeAwareActivation('softmax', 'dense', num_bits=8)])
+    layer = self.TestLayer()
+    layer.activation = QuantizeAwareActivation(
+        activations.get('softmax'), self.quantizer, 0, layer)
+
+    model = keras.Sequential([layer])
 
     x = np.array([[1.0, 2.0]])
     # expected_activation is determined using the float buckets when [-6, 6] is
@@ -61,5 +83,6 @@ def testAppliesQuantizationPreAndPostActivation(self):
 
     self.assertAllClose(expected_activation, model.predict(x))
 
+
 if __name__ == '__main__':
   test.main()

0 commit comments
