@@ -20,6 +20,7 @@
 
 from tensorflow.python.keras import activations
 from tensorflow.python.keras import initializers
+from tensorflow.python.keras.utils import tf_utils
 
 
 class QuantizeAwareActivation(object):
@@ -63,8 +64,6 @@ def __init__(self, activation, quantizer, step, quantize_wrapper):
     self.step = step
     self.quantize_wrapper = quantize_wrapper
 
-    self._training = False
-
     if self._should_pre_quantize():
       self._min_pre_activation, self._max_pre_activation = \
           self._add_range_weights('pre_activation')
@@ -95,22 +94,39 @@ def training(self):
   def training(self, value):
     self._training = value
 
+  def _dict_vars(self, min_var, max_var):
+    return {'min_var': min_var, 'max_var': max_var}
+
   def __call__(self, inputs, *args, **kwargs):
-    # TODO(pulkitb): Add cond here to handle training properly.
+
+    def make_quantizer_fn(training, x, min_var, max_var):
+      """Use currying to return True/False specialized fns to the cond."""
+
+      def quantizer_fn(x=x,
+                       quantizer=self.quantizer,
+                       min_var=min_var,
+                       max_var=max_var):
+        return quantizer(x, self.step, training,
+                         **self._dict_vars(min_var, max_var))
+
+      return quantizer_fn
+
     x = inputs
     if self._should_pre_quantize():
-      x = self.quantizer(
-          x, self.step, self._training, **{
-              'min_var': self._min_pre_activation,
-              'max_var': self._max_pre_activation
-          })
+      x = tf_utils.smart_cond(
+          self._training,
+          make_quantizer_fn(True, x, self._min_pre_activation,
+                            self._max_pre_activation),
+          make_quantizer_fn(False, x, self._min_pre_activation,
+                            self._max_pre_activation))
 
     x = self.activation(x, *args, **kwargs)
 
-    x = self.quantizer(
-        x, self.step, self._training, **{
-            'min_var': self._min_post_activation,
-            'max_var': self._max_post_activation
-        })
+    x = tf_utils.smart_cond(
+        self._training,
+        make_quantizer_fn(True, x, self._min_post_activation,
+                          self._max_post_activation),
+        make_quantizer_fn(False, x, self._min_post_activation,
+                          self._max_post_activation))
 
     return x
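Note: `tf_utils.smart_cond` expects zero-argument callables for its branches, which is why the change curries `training` and the min/max variables into `quantizer_fn` through default arguments: defaults are bound when the function is created, so each branch closes over its own values. A minimal, self-contained sketch of the pattern follows; `fake_quant` is a hypothetical stand-in for the project's quantizer, not code from this commit.

import tensorflow as tf
from tensorflow.python.keras.utils import tf_utils

def fake_quant(x, training):
  # Hypothetical stand-in quantizer: coarsen values when training,
  # pass them through unchanged otherwise.
  return tf.round(x * 4.0) / 4.0 if training else x

def make_quantizer_fn(training, x):
  """Curry `training` and `x` into the zero-arg fn smart_cond expects."""
  def quantizer_fn(x=x, training=training):
    return fake_quant(x, training)
  return quantizer_fn

x = tf.constant([0.1, 0.6])
# With a Python bool predicate, smart_cond calls the matching fn directly;
# with a tensor predicate (e.g. the Keras learning phase) it emits a tf.cond
# and the branch is chosen when the graph executes.
y = tf_utils.smart_cond(True,
                        make_quantizer_fn(True, x),
                        make_quantizer_fn(False, x))

Building both branches up front, rather than testing `self._training` with a Python `if`, keeps the layer correct when `training` is a symbolic tensor, which a plain `if` cannot branch on.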