import numpy as np

from tensorflow.python import keras
+from tensorflow.python.keras import backend as K
from tensorflow.python.platform import test
from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quant_annotate
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
@@ -195,6 +196,20 @@ def _get_annotated_functional_model(self):
    return keras.Model(inputs=inputs, outputs=results)

+  def _assert_weights_equal_value(self, annotated_weights, emulated_weights):
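+    # `K.batch_get_value` evaluates all the variables in one call and
+    # returns their values as numpy arrays.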
+    annotated_weight_values = K.batch_get_value(annotated_weights)
+    emulated_weight_values = K.batch_get_value(emulated_weights)
+
+    self.assertEqual(len(annotated_weight_values), len(emulated_weight_values))
+    for aw, ew in zip(annotated_weight_values, emulated_weight_values):
+      self.assertAllClose(aw, ew)
+
+  def _assert_weights_different_objects(
+      self, annotated_weights, emulated_weights):
+    self.assertEqual(len(annotated_weights), len(emulated_weights))
+    for aw, ew in zip(annotated_weights, emulated_weights):
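+      # Compare object identity: equal values are fine, aliased variable
+      # objects are not.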
+      self.assertNotEqual(id(aw), id(ew))
+
  def _assert_layer_emulated(
      self, annotated_layer, emulated_layer, exclude_keys=None):
    self.assertIsInstance(emulated_layer, QuantizeEmulateWrapper)
@@ -216,6 +231,20 @@ def _assert_layer_emulated(

    self.assertEqual(annotated_config, emulated_config)

+    def _sort_weights(weights):
+      # Variables are named `quantize_annotate0/kernel:0` and
+      # `quantize_emulate0/kernel:0`. Strip layer name to sort.
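+      # e.g. both names map to the sort key `kernel:0`, so matching
+      # weights line up across the two models.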
+      return sorted(weights, key=lambda w: w.name.split('/')[1])
+
+    annotated_weights = _sort_weights(annotated_layer.trainable_weights)
+    emulated_weights = _sort_weights(emulated_layer.trainable_weights)
+
+    # The quantized model should pick up the same weight values as the
+    # original model. However, they should not be the same weight objects;
+    # we don't want training the quantized model to change weights in the
+    # original model.
+    self._assert_weights_different_objects(annotated_weights, emulated_weights)
+    self._assert_weights_equal_value(annotated_weights, emulated_weights)
+
  def _assert_model_emulated(
      self, annotated_model, emulated_model, exclude_keys=None):
    for annotated_layer, emulated_layer in zip(annotated_model.layers,