
Commit 9d0ca18

nutsiepully authored and tensorflower-gardener committed
Clone weights of layers before quantizing.

Ensures that the variables/layers of cloned models carry over their weights, so training can proceed on the original layers independently of the quantized copy. Tests verify this as quantization is applied to annotated layers.

PiperOrigin-RevId: 255691355
1 parent 952fafd commit 9d0ca18
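
The fix builds on standard Keras behavior: `keras.models.clone_model` rebuilds each layer from its config, creating fresh variables with new initial values, and `set_weights` then copies the original values into them. A minimal sketch of that behavior (not part of the commit; the toy model and the public `tensorflow.keras` import are illustrative):

import numpy as np
from tensorflow import keras

# Toy model; any built Keras model behaves the same way.
model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])

clone = keras.models.clone_model(model)  # fresh variables, newly initialized
clone.set_weights(model.get_weights())   # copy values from the original

# Same values, but distinct variable objects: training the clone leaves
# the original model's weights untouched.
for original_value, cloned_value in zip(model.get_weights(),
                                        clone.get_weights()):
  np.testing.assert_allclose(original_value, cloned_value)
assert all(id(a) != id(b) for a, b in zip(model.weights, clone.weights))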

File tree

2 files changed (+43, -9 lines)

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate.py

Lines changed: 14 additions & 9 deletions
@@ -172,8 +172,11 @@ def quantize_apply(model):
         'annotated with `quantize_annotate`. There are no layers '
         'to quantize.')
 
-  def _clone_layer(layer):
-    return layer.__class__.from_config(layer.get_config())
+  def _clone_model_with_weights(model_to_clone):
+    cloned_model = keras.models.clone_model(model_to_clone)
+    cloned_model.set_weights(model_to_clone.get_weights())
+
+    return cloned_model
 
   def _quantize_activation(activation, parent_class, quantize_params):
     try:
@@ -194,10 +197,13 @@ def _get_quantize_activation_params(layer):
     return quant_params
 
   def _apply_quantization(quant_annotate_layer):
-    layer_to_quantize = _clone_layer(quant_annotate_layer.layer)
-    quantize_params = quant_annotate_layer.get_quantize_params()
+    return QuantizeEmulateWrapper(
+        quant_annotate_layer.layer,
+        **(quant_annotate_layer.get_quantize_params()))
 
-    return QuantizeEmulateWrapper(layer_to_quantize, **quantize_params)
+  # Create a copy of the model with the same weights. We can then quantize this
+  # model without modifying the weights of the original model.
+  model_copy = _clone_model_with_weights(model)
 
   # Apply all graph level transformations.
   replace_map = {}
@@ -206,7 +212,7 @@ def _apply_quantization(quant_annotate_layer):
   # Dense(activation='relu') -> Dense(activation=QuantAwareActivation('relu'))
   # TODO(pulkitb): Not all layers (LSTMs) have just activation. Add
   # generic handling for all layers.
-  for layer in model.layers:
+  for layer in model_copy.layers:
     if isinstance(layer, quant_annotate.QuantizeAnnotate) and \
         (layer.layer.activation is not None and
          layer.layer.activation != keras.activations.linear):
@@ -225,13 +231,12 @@ def _add_quant_emulate_wrapper(layer):  # pylint: disable=missing-docstring
     if layer in replace_map:
       return replace_map[layer]
 
-    # No need to quantize layer. Simply clone and return.
     if not isinstance(layer, quant_annotate.QuantizeAnnotate):
-      return _clone_layer(layer)
+      return layer
 
     # Use QuantizeEmulate wrapper on annotated layer which actually adds
     # quantization ops.
     return _apply_quantization(layer)
 
   return keras.models.clone_model(
-      model, input_tensors=None, clone_function=_add_quant_emulate_wrapper)
+      model_copy, input_tensors=None, clone_function=_add_quant_emulate_wrapper)
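
The rewritten `quantize_apply` relies on the `clone_function` hook of `keras.models.clone_model`: each layer of the copy is passed through `_add_quant_emulate_wrapper`, which returns unannotated layers unchanged and substitutes a `QuantizeEmulateWrapper` for annotated ones. A sketch of that hook with a hypothetical stand-in wrapper (the wrapper class and model below are illustrative, not the commit's code):

from tensorflow import keras

class PassthroughWrapper(keras.layers.Wrapper):
  """Hypothetical stand-in for QuantizeEmulateWrapper; just delegates."""

  def call(self, inputs):
    return self.layer(inputs)

def _wrap_dense(layer):
  # Same shape as _add_quant_emulate_wrapper: substitute matching layers,
  # return everything else unchanged.
  if isinstance(layer, keras.layers.Dense):
    return PassthroughWrapper(layer)
  return layer

model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])
wrapped = keras.models.clone_model(model, clone_function=_wrap_dense)
assert isinstance(wrapped.layers[-1], PassthroughWrapper)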

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_test.py

Lines changed: 29 additions & 0 deletions
@@ -21,6 +21,7 @@
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.keras import backend as K
 from tensorflow.python.platform import test
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_annotate as quant_annotate
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
@@ -195,6 +196,20 @@ def _get_annotated_functional_model(self):
 
     return keras.Model(inputs=inputs, outputs=results)
 
+  def _assert_weights_equal_value(self, annotated_weights, emulated_weights):
+    annotated_weight_values = K.batch_get_value(annotated_weights)
+    emulated_weight_values = K.batch_get_value(emulated_weights)
+
+    self.assertEqual(len(annotated_weight_values), len(emulated_weight_values))
+    for aw, ew in zip(annotated_weight_values, emulated_weight_values):
+      self.assertAllClose(aw, ew)
+
+  def _assert_weights_different_objects(
+      self, annotated_weights, emulated_weights):
+    self.assertEqual(len(annotated_weights), len(emulated_weights))
+    for aw, ew in zip(annotated_weights, emulated_weights):
+      self.assertNotEqual(id(aw), id(ew))
+
   def _assert_layer_emulated(
       self, annotated_layer, emulated_layer, exclude_keys=None):
     self.assertIsInstance(emulated_layer, QuantizeEmulateWrapper)
@@ -216,6 +231,20 @@ def _assert_layer_emulated(
 
     self.assertEqual(annotated_config, emulated_config)
 
+    def _sort_weights(weights):
+      # Variables are named `quantize_annotate0/kernel:0` and
+      # `quantize_emulate0/kernel:0`. Strip layer name to sort.
+      return sorted(weights, key=lambda w: w.name.split('/')[1])
+
+    annotated_weights = _sort_weights(annotated_layer.trainable_weights)
+    emulated_weights = _sort_weights(emulated_layer.trainable_weights)
+
+    # Quantized model should pick the same weight values from the original
+    # model. However, they should not be the same weight objects. We don't
+    # want training the quantized model to change weights in the original model.
+    self._assert_weights_different_objects(annotated_weights, emulated_weights)
+    self._assert_weights_equal_value(annotated_weights, emulated_weights)
+
   def _assert_model_emulated(
       self, annotated_model, emulated_model, exclude_keys=None):
     for annotated_layer, emulated_layer in zip(annotated_model.layers,
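
The `_sort_weights` helper pairs weights across the two models by the part of the variable name after the layer prefix, since the annotated and emulated wrappers prefix variables differently. A small illustration of that key, using the hypothetical names from the code comment:

annotated_names = ['quantize_annotate0/kernel:0', 'quantize_annotate0/bias:0']
emulated_names = ['quantize_emulate0/bias:0', 'quantize_emulate0/kernel:0']

def sort_key(name):
  # Drop the differing layer prefix so both lists sort identically.
  return name.split('/')[1]

# After sorting, index i in one list pairs with index i in the other.
assert [sort_key(n) for n in sorted(annotated_names, key=sort_key)] == \
       [sort_key(n) for n in sorted(emulated_names, key=sort_key)]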
