27
27
from tensorflow_model_optimization .python .core .quantization .keras import quantize_aware_activation
28
28
from tensorflow_model_optimization .python .core .quantization .keras import quantize_emulate
29
29
from tensorflow_model_optimization .python .core .quantization .keras import quantize_emulate_wrapper
30
+ from tensorflow_model_optimization .python .core .quantization .keras import quantize_provider as quantize_provider_mod
30
31
31
32
quantize_annotate = quantize_emulate .quantize_annotate
32
33
QuantizeEmulate = quantize_emulate .QuantizeEmulate
@@ -67,35 +68,20 @@ def testQuantizeEmulateList(self):
67
68
68
69
class QuantizeAnnotateTest (test .TestCase ):
69
70
70
- def setUp (self ):
71
- self .quant_params = {
72
- 'num_bits' : 8 ,
73
- 'narrow_range' : True ,
74
- 'symmetric' : True
75
- }
76
-
77
- def _assertQuantParams (self , layer , quant_params ):
78
- layer_params = {
79
- 'num_bits' : layer ._num_bits ,
80
- 'narrow_range' : layer ._narrow_range ,
81
- 'symmetric' : layer ._symmetric
82
- }
83
- self .assertEqual (quant_params , layer_params )
84
-
85
- def _assertWrappedLayer (self , layer , quant_params ):
71
+ def _assertWrappedLayer (self , layer , quantize_provider = None ):
86
72
self .assertIsInstance (layer , quant_annotate .QuantizeAnnotate )
87
- self ._assertQuantParams ( layer , quant_params )
73
+ self .assertEqual ( quantize_provider , layer . quantize_provider )
88
74
89
- def _assertWrappedSequential (self , model , quant_params ):
75
+ def _assertWrappedModel (self , model ):
90
76
for layer in model .layers :
91
- self ._assertWrappedLayer (layer , quant_params )
77
+ self ._assertWrappedLayer (layer )
92
78
93
79
def testQuantizeAnnotateLayer (self ):
94
80
layer = keras .layers .Dense (10 , input_shape = (5 ,))
95
81
wrapped_layer = quantize_annotate (
96
- layer , input_shape = (5 ,), ** self . quant_params )
82
+ layer , input_shape = (5 ,))
97
83
98
- self ._assertWrappedLayer (wrapped_layer , self . quant_params )
84
+ self ._assertWrappedLayer (wrapped_layer )
99
85
100
86
inputs = np .random .rand (1 , 5 )
101
87
model = keras .Sequential ([layer ])
@@ -110,24 +96,33 @@ def testQuantizeAnnotateModel(self):
110
96
keras .layers .Dense (10 , input_shape = (5 ,)),
111
97
keras .layers .Dropout (0.4 )
112
98
])
113
- annotated_model = quantize_annotate (model , ** self . quant_params )
99
+ annotated_model = quantize_annotate (model )
114
100
115
- self ._assertWrappedSequential (annotated_model , self . quant_params )
101
+ self ._assertWrappedModel (annotated_model )
116
102
117
103
inputs = np .random .rand (1 , 5 )
118
104
self .assertAllEqual (model .predict (inputs ), annotated_model .predict (inputs ))
119
105
120
106
def testQuantizeAnnotateModel_HasAnnotatedLayers (self ):
121
- layer_params = {'num_bits' : 4 , 'narrow_range' : False , 'symmetric' : False }
107
+ class TestQuantizeProvider (quantize_provider_mod .QuantizeProvider ):
108
+
109
+ def get_weights_and_quantizers (self , layer ):
110
+ pass
111
+
112
+ def get_activations_and_quantizers (self , layer ):
113
+ pass
114
+
115
+ quantize_provider = TestQuantizeProvider ()
116
+
122
117
model = keras .Sequential ([
123
118
keras .layers .Dense (10 , input_shape = (5 ,)),
124
- quantize_annotate (keras .layers .Dense (5 ), ** layer_params )
119
+ quant_annotate .QuantizeAnnotate (
120
+ keras .layers .Dense (5 ), quantize_provider = quantize_provider )
125
121
])
122
+ annotated_model = quantize_annotate (model )
126
123
127
- annotated_model = quantize_annotate (model , ** self .quant_params )
128
-
129
- self ._assertWrappedLayer (annotated_model .layers [0 ], self .quant_params )
130
- self ._assertWrappedLayer (annotated_model .layers [1 ], layer_params )
124
+ self ._assertWrappedLayer (annotated_model .layers [0 ])
125
+ self ._assertWrappedLayer (annotated_model .layers [1 ], quantize_provider )
131
126
# Ensure an already annotated layer is not wrapped again.
132
127
self .assertIsInstance (annotated_model .layers [1 ].layer , keras .layers .Dense )
133
128
@@ -181,7 +176,7 @@ def testRaisesErrorNoAnnotatedLayers_Functional(self):
181
176
182
177
def testRaisesErrorModelNotBuilt (self ):
183
178
model = keras .Sequential ([
184
- quantize_annotate (keras .layers .Dense (10 ), ** self . quant_params1 )])
179
+ quantize_annotate (keras .layers .Dense (10 ))])
185
180
186
181
self .assertFalse (model .built )
187
182
with self .assertRaises (ValueError ):
@@ -191,16 +186,15 @@ def testRaisesErrorModelNotBuilt(self):
191
186
192
187
def _get_annotated_sequential_model (self ):
193
188
return keras .Sequential ([
194
- quantize_annotate (keras .layers .Conv2D (32 , 5 ), input_shape = (28 , 28 , 1 ),
195
- ** self .quant_params1 ),
196
- quantize_annotate (keras .layers .Dense (10 ), ** self .quant_params2 )
189
+ quantize_annotate (keras .layers .Conv2D (32 , 5 ), input_shape = (28 , 28 , 1 )),
190
+ quantize_annotate (keras .layers .Dense (10 ))
197
191
])
198
192
199
193
def _get_annotated_functional_model (self ):
200
194
inputs = keras .Input (shape = (28 , 28 , 1 ))
201
195
x = quantize_annotate (
202
- keras .layers .Conv2D (32 , 5 ), ** self . quant_params1 )(inputs )
203
- results = quantize_annotate (keras .layers .Dense (10 ), ** self . quant_params2 )(x )
196
+ keras .layers .Conv2D (32 , 5 ))(inputs )
197
+ results = quantize_annotate (keras .layers .Dense (10 ))(x )
204
198
205
199
return keras .Model (inputs = inputs , outputs = results )
206
200
@@ -283,8 +277,7 @@ def testQuantizesActivationsWithinLayer_Sequential(self):
283
277
model = keras .Sequential ([
284
278
quantize_annotate (
285
279
keras .layers .Conv2D (32 , 5 , activation = 'relu' ),
286
- input_shape = (28 , 28 , 1 ),
287
- ** quant_params )
280
+ input_shape = (28 , 28 , 1 ))
288
281
])
289
282
290
283
quantized_model = quantize_emulate .quantize_apply (model )
@@ -304,8 +297,7 @@ def testQuantizesActivationsWithinLayer_Functional(self):
304
297
305
298
inputs = keras .Input (shape = (28 , 28 , 1 ))
306
299
results = quantize_annotate (
307
- keras .layers .Conv2D (32 , 5 , activation = 'relu' ),
308
- ** self .quant_params1 )(inputs )
300
+ keras .layers .Conv2D (32 , 5 , activation = 'relu' ))(inputs )
309
301
model = keras .Model (inputs = inputs , outputs = results )
310
302
311
303
quantized_model = quantize_emulate .quantize_apply (model )
0 commit comments