
Commit 18e87d2

Xhark authored and tensorflower-gardener committed
Add all trainable variables to the wrapper and preserve the order of the trainable weights as much as possible for custom layers.
Note that this CL increases the coverage of weight-order preservation, but it doesn't guarantee the same order in every case, because it relies on the Keras _dedup_weights method, which removes duplicated weights while maintaining order as much as possible. PiperOrigin-RevId: 391542287
1 parent d155948 commit 18e87d2
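The ordering behavior described in the commit message comes from Keras's internal _dedup_weights helper, which drops duplicate variable references while keeping the first occurrence of each. A minimal standalone sketch of that order-preserving deduplication (illustrative only, not the Keras internal itself):

def dedup_weights(weights):
  # Keep the first occurrence of each variable (identity-based), so weights
  # the wrapper collected first keep their position and repeats contributed
  # by the inner layer are dropped.
  seen_ids = set()
  deduped = []
  for weight in weights:
    if id(weight) not in seen_ids:
      seen_ids.add(id(weight))
      deduped.append(weight)
  return deduped

With this change, the wrapper effectively returns dedup_weights(wrapper_weights + inner_layer_weights), so the inner layer's weights, added to the wrapper during build(), come first and any duplicates are discarded.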

File tree

3 files changed: +69 −3 lines changed


tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 1 addition & 0 deletions
@@ -266,6 +266,7 @@ py_strict_test(
         ":quantize_config",
         ":quantize_layer",
         ":quantize_wrapper",
+        ":quantizers",
         # numpy dep1,
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/keras:test_utils",

tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py

Lines changed: 64 additions & 0 deletions
@@ -27,6 +27,7 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper as quantize_wrapper_mod
+from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 
 quantize_annotate_layer = quantize.quantize_annotate_layer
@@ -513,6 +514,69 @@ def testQuantizeApply_RunsWhenNestedModelNotAnnotated(self):
 
     quantize_apply(annotated_model)
 
+  class CustomConvLayer(tf.keras.layers.Layer):
+
+    def __init__(self, name=None, **kwargs):
+      super().__init__(name=name, **kwargs)
+      self.conv1 = tf.keras.layers.Conv2D(2, 2)
+
+    def build(self, input_shape):
+      self.conv1.build(input_shape)
+
+    def call(self, inputs):
+      return self.conv1(inputs)
+
+    def get_config(self):
+      return {'name': self.name}
+
+  class CustomConvQuantizeConfig(quantize_config_mod.QuantizeConfig):
+
+    def get_weights_and_quantizers(self, layer):
+      return [(layer.conv1.kernel, quantizers.LastValueQuantizer(
+          num_bits=8, symmetric=True, narrow_range=False, per_axis=False)),]
+
+    def get_activations_and_quantizers(self, layer):
+      return []
+
+    def set_quantize_weights(self, layer, quantize_weights):
+      # layer.conv1._kernel_bak = layer.conv1.kernel
+      layer.conv1.kernel = quantize_weights[0]
+
+    def set_quantize_activations(self, layer, quantize_activations):
+      pass
+
+    def get_output_quantizers(self, layer):
+      return []
+
+    def get_config(self):
+      return {}
+
+  def testQuantizeApply_KeepTrainableWeightOrder(self):
+    layer = self.CustomConvLayer(input_shape=(28, 28, 3))
+    model = keras.Sequential([layer])
+
+    def apply_quantization_to_dense(layer):
+      if isinstance(layer, self.CustomConvLayer):
+        return quantize_annotate_layer(
+            layer, quantize_config=self.CustomConvQuantizeConfig())
+      return layer
+
+    annotated_model = tf.keras.models.clone_model(
+        model,
+        clone_function=apply_quantization_to_dense,
+    )
+
+    with quantize.quantize_scope({
+        'CustomConvQuantizeConfig': self.CustomConvQuantizeConfig,
+        'CustomConvLayer': self.CustomConvLayer
+    }):
+      quant_aware_model = quantize_apply(annotated_model)
+
+    self._assert_weights_different_objects(
+        model.trainable_weights, quant_aware_model.trainable_weights)
+    self._assert_weights_equal_value(
+        model.trainable_weights, quant_aware_model.trainable_weights)
+
 
 if __name__ == '__main__':
   tf.test.main()
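The new test relies on two assertion helpers, _assert_weights_different_objects and _assert_weights_equal_value, which are defined elsewhere in quantize_test.py and are not part of this diff. A plausible sketch of what they check, assuming tf.test.TestCase assertions (not the actual implementation):

  def _assert_weights_different_objects(self, weights1, weights2):
    # Same number of weights, and each position holds a distinct variable
    # object, since quantize_apply clones the model rather than reusing
    # the original variables.
    self.assertEqual(len(weights1), len(weights2))
    for w1, w2 in zip(weights1, weights2):
      self.assertIsNot(w1, w2)

  def _assert_weights_equal_value(self, weights1, weights2):
    # Position-by-position value equality, which only holds if the trainable
    # weight order is preserved after applying quantization.
    self.assertEqual(len(weights1), len(weights2))
    for w1, w2 in zip(weights1, weights2):
      self.assertAllEqual(w1, w2)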

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 4 additions & 3 deletions
@@ -99,15 +99,14 @@ def build(self, input_shape):
         dtype=tf.dtypes.int32,
         trainable=False)
 
+    self._trainable_weights.extend(self.layer.trainable_weights)
     self._weight_vars = []
     for weight, quantizer in \
         self.quantize_config.get_weights_and_quantizers(self.layer):
       quantizer_vars = quantizer.build(weight.shape,
                                        self._weight_name(weight.name), self)
 
       self._weight_vars.append((weight, quantizer, quantizer_vars))
-      # Needed to ensure unquantized weights get trained as part of the wrapper.
-      self._trainable_weights.append(weight)
 
     self._quantize_activations = []
     for activation, quantizer in \
@@ -215,7 +214,9 @@ def trainable(self, value):
 
   @property
   def trainable_weights(self):
-    return self.layer.trainable_weights + self._trainable_weights
+    # Change the order to keep the weight order after applying QAT.
+    return self._dedup_weights(
+        self._trainable_weights + self.layer.trainable_weights)
 
   @property
   def non_trainable_weights(self):
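One way to observe the effect of this wrapper change is through the public tfmot API. A sketch (exact variable names and the set of trainable quantizer variables depend on the layer and quantization scheme):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

float_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(2, 2, input_shape=(8, 8, 3)),
])
qat_model = tfmot.quantization.keras.quantize_model(float_model)

# With the ordering change, the float model's kernel and bias appear first in
# the quantize-aware model's trainable_weights, in the same relative order,
# followed by any trainable variables introduced by the quantize wrapper.
print([w.name for w in float_model.trainable_weights])
print([w.name for w in qat_model.trainable_weights])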
