|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | # ==============================================================================
|
15 |
| -"""Tests for keras pruning wrapper.""" |
| 15 | +"""Tests for quantize API functions.""" |
16 | 16 |
|
17 | 17 | from __future__ import absolute_import
|
18 | 18 | from __future__ import division
|
|
23 | 23 | from tensorflow.python import keras
|
24 | 24 | from tensorflow.python.platform import test
|
25 | 25 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quant_annotate
|
| 26 | +from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation |
26 | 27 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate
|
27 |
| -from tensorflow_model_optimization.python.core.quantization.keras.quantize_emulate import QuantizeEmulate |
28 |
| -from tensorflow_model_optimization.python.core.quantization.keras.quantize_emulate_wrapper import QuantizeEmulateWrapper |
| 28 | +from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate_wrapper |
29 | 29 |
|
# Convenience aliases for the symbols under test, matching the file's
# existing style of hoisting frequently used names to module level.
quantize_annotate = quantize_emulate.quantize_annotate
QuantizeEmulate = quantize_emulate.QuantizeEmulate
QuantizeEmulateWrapper = quantize_emulate_wrapper.QuantizeEmulateWrapper
QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation
31 | 34 |
|
32 | 35 |
|
33 | 36 | class QuantizeEmulateTest(test.TestCase):
|
@@ -131,5 +134,155 @@ def testQuantizeAnnotateModel_HasAnnotatedLayers(self):
|
131 | 134 | self.assertAllEqual(model.predict(inputs), annotated_model.predict(inputs))
|
132 | 135 |
|
133 | 136 |
|
class QuantizeApplyTest(test.TestCase):
  """Tests for `quantize_emulate.quantize_apply`.

  Covers input validation, application of quantization wrappers to
  annotated Sequential/Functional models, and the activation-replacement
  transformation performed during apply.
  """

  def setUp(self):
    super(QuantizeApplyTest, self).setUp()
    # Two distinct parameter sets so tests can verify per-layer params
    # are carried through from annotation to the emulation wrapper.
    self.quant_params1 = {
        'num_bits': 8,
        'narrow_range': True,
        'symmetric': True
    }
    self.quant_params2 = {
        'num_bits': 4,
        'narrow_range': False,
        'symmetric': False
    }

  # Validation tests

  def testRaisesErrorIfNotKerasModel(self):
    # A bare layer is not a model; apply must reject it.
    with self.assertRaises(ValueError):
      quantize_emulate.quantize_apply(keras.layers.Dense(32))

  def testRaisesErrorIfKerasSubclassedModel(self):
    class MyModel(keras.Model):
      def call(self, inputs, training=None, mask=None):  # pylint: disable=g-wrong-blank-lines
        return inputs

    # Subclassed models cannot be cloned/transformed, so apply rejects them.
    with self.assertRaises(ValueError):
      quantize_emulate.quantize_apply(MyModel())

  def testRaisesErrorNoAnnotatedLayers_Sequential(self):
    model = keras.Sequential([
        keras.layers.Dense(10), keras.layers.Dropout(0.4)])

    with self.assertRaises(ValueError):
      quantize_emulate.quantize_apply(model)

  def testRaisesErrorNoAnnotatedLayers_Functional(self):
    inputs = keras.Input(shape=(10,))
    x = keras.layers.Dense(32, activation='relu')(inputs)
    results = keras.layers.Dense(5, activation='softmax')(x)
    model = keras.Model(inputs=inputs, outputs=results)

    with self.assertRaises(ValueError):
      quantize_emulate.quantize_apply(model)

  # Quantization Apply Tests

  def _get_annotated_sequential_model(self):
    """Returns a Sequential model with every layer annotated."""
    return keras.Sequential([
        quantize_annotate(keras.layers.Conv2D(32, 5), input_shape=(28, 28, 1),
                          **self.quant_params1),
        quantize_annotate(keras.layers.Dense(10), **self.quant_params2)
    ])

  def _get_annotated_functional_model(self):
    """Returns a Functional model with every layer annotated."""
    inputs = keras.Input(shape=(28, 28, 1))
    x = quantize_annotate(
        keras.layers.Conv2D(32, 5), **self.quant_params1)(inputs)
    results = quantize_annotate(keras.layers.Dense(10), **self.quant_params2)(x)

    return keras.Model(inputs=inputs, outputs=results)

  def _assert_layer_emulated(
      self, annotated_layer, emulated_layer, exclude_keys=None):
    """Checks an annotated layer was rewrapped as a QuantizeEmulateWrapper.

    Args:
      annotated_layer: QuantizeAnnotate wrapper from the input model.
      emulated_layer: Corresponding layer from the quantized model.
      exclude_keys: Optional list of inner-layer config keys to skip when
        comparing configs (e.g. keys the transform intentionally rewrites).
    """
    self.assertIsInstance(emulated_layer, QuantizeEmulateWrapper)

    # Quantization parameters must survive the annotate -> apply transition.
    self.assertEqual(annotated_layer.get_quantize_params(),
                     emulated_layer.get_quantize_params())

    # Extract configs of the inner layers they wrap.
    annotated_config = annotated_layer.layer.get_config()
    emulated_config = emulated_layer.layer.get_config()

    # The underlying layers aren't always exactly the same. For example,
    # activations in the underlying layers might be replaced. Exclude keys
    # if required.
    if exclude_keys:
      for key in exclude_keys:
        annotated_config.pop(key)
        emulated_config.pop(key)

    self.assertEqual(annotated_config, emulated_config)

  def _assert_model_emulated(
      self, annotated_model, emulated_model, exclude_keys=None):
    """Checks every non-input layer pair across the two models."""
    for annotated_layer, emulated_layer in zip(annotated_model.layers,
                                               emulated_model.layers):
      # Functional models have an explicit InputLayer with no annotation.
      if isinstance(emulated_layer, keras.layers.InputLayer):
        continue

      self._assert_layer_emulated(annotated_layer, emulated_layer, exclude_keys)

  def _assert_conv_activation_quantized(self, conv_layer, quant_params):
    """Checks a Conv2D's activation was replaced by QuantizeAwareActivation."""
    self.assertIsInstance(conv_layer.activation, QuantizeAwareActivation)
    # The original activation is preserved inside the quantize-aware wrapper.
    self.assertEqual(
        keras.activations.get('relu'), conv_layer.activation.activation)
    self.assertEqual(keras.layers.Conv2D, conv_layer.activation.parent_layer)
    self.assertEqual(quant_params, conv_layer.activation.get_quantize_params())

  def testAppliesQuantizationToAnnotatedModel_Sequential(self):
    model = self._get_annotated_sequential_model()

    quantized_model = quantize_emulate.quantize_apply(model)

    self._assert_model_emulated(model, quantized_model)

  def testAppliesQuantizationToAnnotatedModel_Functional(self):
    model = self._get_annotated_functional_model()

    quantized_model = quantize_emulate.quantize_apply(model)

    self._assert_model_emulated(model, quantized_model)

  # Transformation Tests

  def testQuantizesActivationsWithinLayer_Sequential(self):
    quant_params = {'num_bits': 8, 'symmetric': True}
    model = keras.Sequential([
        quantize_annotate(
            keras.layers.Conv2D(32, 5, activation='relu'),
            input_shape=(28, 28, 1),
            **quant_params)
    ])

    quantized_model = quantize_emulate.quantize_apply(model)

    # We expect activation to be modified.
    self._assert_model_emulated(model, quantized_model, ['activation'])

    self._assert_conv_activation_quantized(
        quantized_model.layers[0].layer, quant_params)

  def testQuantizesActivationsWithinLayer_Functional(self):
    quant_params = {'num_bits': 8, 'symmetric': True}

    inputs = keras.Input(shape=(28, 28, 1))
    # Bug fix: annotate with the same params that are asserted below.
    # Previously this passed self.quant_params1 (which includes
    # 'narrow_range'), making the final params assertion inconsistent
    # with the setup and diverging from the Sequential twin test.
    results = quantize_annotate(
        keras.layers.Conv2D(32, 5, activation='relu'),
        **quant_params)(inputs)
    model = keras.Model(inputs=inputs, outputs=results)

    quantized_model = quantize_emulate.quantize_apply(model)

    # We expect activation to be modified.
    self._assert_model_emulated(model, quantized_model, ['activation'])

    # layers[0] is the InputLayer in a Functional model.
    self._assert_conv_activation_quantized(
        quantized_model.layers[1].layer, quant_params)
| 286 | + |
if __name__ == '__main__':
  # Discovers and runs all TestCase methods in this module.
  test.main()
|
0 commit comments