
Commit 624155b

nutsiepully authored and tensorflower-gardener committed
Add training support to QuantizeAwareActivation
Ensures the __call__ method constructs quantization operations while taking `training` into account.

PiperOrigin-RevId: 264455646
1 parent ea8e8d4 commit 624155b

2 files changed (+35, -14 lines)


tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py

Lines changed: 29 additions & 13 deletions
@@ -20,6 +20,7 @@
 
 from tensorflow.python.keras import activations
 from tensorflow.python.keras import initializers
+from tensorflow.python.keras.utils import tf_utils
 
 
 class QuantizeAwareActivation(object):
@@ -63,8 +64,6 @@ def __init__(self, activation, quantizer, step, quantize_wrapper):
     self.step = step
     self.quantize_wrapper = quantize_wrapper
 
-    self._training = False
-
     if self._should_pre_quantize():
       self._min_pre_activation, self._max_pre_activation = \
           self._add_range_weights('pre_activation')
@@ -95,22 +94,39 @@ def training(self):
   def training(self, value):
     self._training = value
 
+  def _dict_vars(self, min_var, max_var):
+    return {'min_var': min_var, 'max_var': max_var}
+
   def __call__(self, inputs, *args, **kwargs):
-    # TODO(pulkitb): Add cond here to handle training properly.
+
+    def make_quantizer_fn(training, x, min_var, max_var):
+      """Use currying to return True/False specialized fns to the cond."""
+
+      def quantizer_fn(x=x,
+                       quantizer=self.quantizer,
+                       min_var=min_var,
+                       max_var=max_var):
+        return quantizer(x, self.step, training,
+                         **self._dict_vars(min_var, max_var))
+
+      return quantizer_fn
+
     x = inputs
     if self._should_pre_quantize():
-      x = self.quantizer(
-          x, self.step, self._training, **{
-              'min_var': self._min_pre_activation,
-              'max_var': self._max_pre_activation
-          })
+      x = tf_utils.smart_cond(
+          self._training,
+          make_quantizer_fn(True, x, self._min_pre_activation,
+                            self._max_pre_activation),
+          make_quantizer_fn(False, x, self._min_pre_activation,
+                            self._max_pre_activation))
 
     x = self.activation(x, *args, **kwargs)
 
-    x = self.quantizer(
-        x, self.step, self._training, **{
-            'min_var': self._min_post_activation,
-            'max_var': self._max_post_activation
-        })
+    x = tf_utils.smart_cond(
+        self._training,
+        make_quantizer_fn(True, x, self._min_post_activation,
+                          self._max_post_activation),
+        make_quantizer_fn(False, x, self._min_post_activation,
+                          self._max_post_activation))
 
     return x
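
For context, here is a minimal standalone sketch (not part of this commit) of the currying pattern the new __call__ relies on: a tf.cond-style conditional expects zero-argument callables, so make_quantizer_fn bakes the training flag and the min/max range variables into a closure for each branch. The fake_quantize helper below is hypothetical and only stands in for the library's quantizer; it is not the real implementation.

import tensorflow as tf

def fake_quantize(x, training, min_var, max_var):
  # Hypothetical stand-in for the real quantizer: when training, widen the
  # tracked range from the batch; at inference, only clip to the stored range.
  if training:
    min_var.assign(tf.minimum(min_var, tf.reduce_min(x)))
    max_var.assign(tf.maximum(max_var, tf.reduce_max(x)))
  return tf.clip_by_value(x, min_var, max_var)

def make_quantizer_fn(training, x, min_var, max_var):
  """Curry `training` and the range variables into a no-arg fn for the cond."""
  def quantizer_fn():
    return fake_quantize(x, training, min_var, max_var)
  return quantizer_fn

x = tf.constant([-3.0, 0.5, 4.0])
min_var = tf.Variable(0.0)
max_var = tf.Variable(1.0)
is_training = tf.constant(True)  # could also be a plain Python bool

y = tf.cond(is_training,
            make_quantizer_fn(True, x, min_var, max_var),
            make_quantizer_fn(False, x, min_var, max_var))

The library routes this through tf_utils.smart_cond rather than tf.cond directly: when self._training is a plain Python bool, smart_cond simply calls the matching branch and builds no conditional op; when it is a tensor such as the Keras learning phase, it falls back to a graph-level cond.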

tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py

Lines changed: 6 additions & 1 deletion
@@ -22,6 +22,7 @@
 
 from tensorflow.python import keras
 from tensorflow.python.keras import activations
+from tensorflow.python.keras import backend as K
 from tensorflow.python.platform import test
 
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
@@ -40,7 +41,11 @@ def setUp(self):
 
     class TestLayer(keras.layers.Layer):
 
-      def call(self, inputs):
+      def call(self, inputs, training=None):
+        if training is None:
+          training = K.learning_phase()
+
+        self.activation.training = training
         return self.activation(inputs)
 
      def compute_output_shape(self, input_shape):
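
The test change above shows the intended wiring. Below is a short usage sketch, assuming a hypothetical wrapper layer (the class name and constructor are illustrative, not part of the library): Keras passes training into call(), the layer falls back to K.learning_phase() when it is None, and then hands the flag to the activation's training setter before invoking it.

import tensorflow as tf
from tensorflow.python.keras import backend as K

class QuantizedActivationLayer(tf.keras.layers.Layer):
  """Hypothetical layer mirroring the test's TestLayer pattern."""

  def __init__(self, quant_activation, **kwargs):
    super(QuantizedActivationLayer, self).__init__(**kwargs)
    # `quant_activation` would be a QuantizeAwareActivation instance.
    self.quant_activation = quant_activation

  def call(self, inputs, training=None):
    if training is None:
      # Fall back to the global learning phase, as the updated test does.
      training = K.learning_phase()
    # The `training` setter stores the flag; the activation's __call__ then
    # builds a smart_cond on it around the pre/post-activation quantizers.
    self.quant_activation.training = training
    return self.quant_activation(inputs)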
