|
24 | 24 | from tensorflow.python.platform import test
|
25 | 25 |
|
26 | 26 | from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate
|
| 27 | +from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulatable_layer |
| 28 | + |
| 29 | +QuantizeEmulatableLayer = quantize_emulatable_layer.QuantizeEmulatableLayer |
27 | 30 |
|
28 | 31 |
|
29 | 32 | class QuantizeAnnotateTest(test.TestCase):
|
30 | 33 |
|
31 |
| - def testAppliesWrapperToAllClasses(self): |
32 |
| - layer = keras.layers.Dense(5, activation='relu', input_shape=(10,)) |
| 34 | + def setUp(self): |
| 35 | + self.quant_params = { |
| 36 | + 'num_bits': 8, |
| 37 | + 'narrow_range': True, |
| 38 | + 'symmetric': True |
| 39 | + } |
| 40 | + |
| 41 | + def testRaisesErrorForUnsupportedLayer(self): |
| 42 | + class CustomLayer(keras.layers.Dense): |
| 43 | + pass |
| 44 | + |
| 45 | + with self.assertRaises(ValueError): |
| 46 | + quantize_annotate.QuantizeAnnotate(CustomLayer(10), **self.quant_params) |
| 47 | + |
| 48 | + def testAnnotatesCustomQuantizableLayer(self): |
| 49 | + class CustomLayerQuantizable(keras.layers.Dense, QuantizeEmulatableLayer): |
| 50 | + def get_quantizable_weights(self): # pylint: disable=g-wrong-blank-lines |
| 51 | + return [self.kernel] |
| 52 | + |
| 53 | + def set_quantizable_weights(self, weights): |
| 54 | + self.kernel = weights[0] |
33 | 55 |
|
| 56 | + annotated_layer = quantize_annotate.QuantizeAnnotate( |
| 57 | + CustomLayerQuantizable(10), **self.quant_params) |
| 58 | + |
| 59 | + self.assertIsInstance(annotated_layer.layer, CustomLayerQuantizable) |
| 60 | + self.assertEqual( |
| 61 | + self.quant_params, annotated_layer.get_quantize_params()) |
| 62 | + |
| 63 | + def testAnnotatesKerasLayer(self): |
| 64 | + layer = keras.layers.Dense(5, activation='relu', input_shape=(10,)) |
34 | 65 | model = keras.Sequential([layer])
|
35 |
| - wrapped_model = keras.Sequential([ |
| 66 | + |
| 67 | + annotated_model = keras.Sequential([ |
36 | 68 | quantize_annotate.QuantizeAnnotate(
|
37 |
| - layer, num_bits=8, input_shape=(10,))]) |
| 69 | + layer, input_shape=(10,), **self.quant_params)]) |
| 70 | + |
| 71 | + annotated_layer = annotated_model.layers[0] |
| 72 | + self.assertIsInstance(annotated_layer.layer, keras.layers.Dense) |
| 73 | + self.assertEqual( |
| 74 | + self.quant_params, annotated_layer.get_quantize_params()) |
38 | 75 |
|
| 76 | + # Annotated model should not affect computation. Returns same results. |
39 | 77 | x_test = np.random.rand(10, 10)
|
40 |
| - self.assertAllEqual(model.predict(x_test), wrapped_model.predict(x_test)) |
| 78 | + self.assertAllEqual(model.predict(x_test), annotated_model.predict(x_test)) |
41 | 79 |
|
42 | 80 | if __name__ == '__main__':
|
43 | 81 | test.main()
|
0 commit comments