
Commit eefc66a

Xhark authored and tensorflower-gardener committed
Add all trainable variables to the wrapper and keep the order of the training weights as much as possible for the custom layer.
Note that this CL increases the coverage of keeping weight order, but doesn't guarantee the same order in every case, because it uses the Keras _dedup_weights method, which removes duplicated weights while maintaining order as much as possible. PiperOrigin-RevId: 393035727
1 parent 18e87d2 commit eefc66a
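
The _dedup_weights behavior mentioned in the commit message removes duplicate variables from a weight list while preserving first-seen order. Below is a minimal illustrative sketch of that behavior; it is not the actual Keras implementation.

def dedup_weights(weights):
  # Order-preserving de-duplication by object identity. Illustrative sketch
  # only; the real Keras _dedup_weights helper may differ in details.
  output, seen = [], set()
  for w in weights:
    if id(w) not in seen:
      seen.add(id(w))
      output.append(w)
  return output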

File tree

3 files changed: +3, -69 lines


tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 0 additions & 1 deletion

@@ -266,7 +266,6 @@ py_strict_test(
         ":quantize_config",
         ":quantize_layer",
         ":quantize_wrapper",
-        ":quantizers",
         # numpy dep1,
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/keras:test_utils",

tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py

Lines changed: 0 additions & 64 deletions

@@ -27,7 +27,6 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper as quantize_wrapper_mod
-from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry

 quantize_annotate_layer = quantize.quantize_annotate_layer

@@ -514,69 +513,6 @@ def testQuantizeApply_RunsWhenNestedModelNotAnnotated(self):

     quantize_apply(annotated_model)

-  class CustomConvLayer(tf.keras.layers.Layer):
-
-    def __init__(self, name=None, **kwargs):
-      super().__init__(name=name, **kwargs)
-      self.conv1 = tf.keras.layers.Conv2D(2, 2)
-
-    def build(self, input_shape):
-      self.conv1.build(input_shape)
-
-    def call(self, inputs):
-      return self.conv1(inputs)
-
-    def get_config(self):
-      return {'name': self.name}
-
-  class CustomConvQuantizeConfig(quantize_config_mod.QuantizeConfig):
-
-    def get_weights_and_quantizers(self, layer):
-      return [(layer.conv1.kernel, quantizers.LastValueQuantizer(
-          num_bits=8, symmetric=True, narrow_range=False, per_axis=False)),]
-
-    def get_activations_and_quantizers(self, layer):
-      return []
-
-    def set_quantize_weights(self, layer, quantize_weights):
-      # layer.conv1._kernel_bak = layer.conv1.kernel
-      layer.conv1.kernel = quantize_weights[0]
-
-    def set_quantize_activations(self, layer, quantize_activations):
-      pass
-
-    def get_output_quantizers(self, layer):
-      return []
-
-    def get_config(self):
-      return {}
-
-  def testQuantizeApply_KeepTrainableWeightOrder(self):
-    layer = self.CustomConvLayer(input_shape=(28, 28, 3))
-    model = keras.Sequential([layer])
-
-    def apply_quantization_to_dense(layer):
-      if isinstance(layer, self.CustomConvLayer):
-        return quantize_annotate_layer(
-            layer, quantize_config=self.CustomConvQuantizeConfig())
-      return layer
-
-    annotated_model = tf.keras.models.clone_model(
-        model,
-        clone_function=apply_quantization_to_dense,
-    )
-
-    with quantize.quantize_scope({
-        'CustomConvQuantizeConfig': self.CustomConvQuantizeConfig,
-        'CustomConvLayer': self.CustomConvLayer
-    }):
-      quant_aware_model = quantize_apply(annotated_model)
-
-    self._assert_weights_different_objects(
-        model.trainable_weights, quant_aware_model.trainable_weights)
-    self._assert_weights_equal_value(
-        model.trainable_weights, quant_aware_model.trainable_weights)
-

 if __name__ == '__main__':
   tf.test.main()
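
The removed test above calls two assertion helpers, _assert_weights_different_objects and _assert_weights_equal_value, which are defined elsewhere in quantize_test.py and are not part of this diff. The following is a hedged sketch of what such helpers typically check; the real implementations may differ.

import tensorflow as tf

class WeightAssertionsSketch(tf.test.TestCase):
  # Illustrative sketch only; not the helpers actually used by quantize_test.py.

  def _assert_weights_different_objects(self, weights_a, weights_b):
    # Quantization wrapping is expected to expose new variable objects.
    self.assertEqual(len(weights_a), len(weights_b))
    for a, b in zip(weights_a, weights_b):
      self.assertIsNot(a, b)

  def _assert_weights_equal_value(self, weights_a, weights_b):
    # The new variables should still hold the original values, in the same order.
    self.assertEqual(len(weights_a), len(weights_b))
    for a, b in zip(weights_a, weights_b):
      self.assertAllClose(a.numpy(), b.numpy())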

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 3 additions & 4 deletions

@@ -99,14 +99,15 @@ def build(self, input_shape):
         dtype=tf.dtypes.int32,
         trainable=False)

-    self._trainable_weights.extend(self.layer.trainable_weights)
     self._weight_vars = []
     for weight, quantizer in \
         self.quantize_config.get_weights_and_quantizers(self.layer):
       quantizer_vars = quantizer.build(weight.shape,
                                        self._weight_name(weight.name), self)

       self._weight_vars.append((weight, quantizer, quantizer_vars))
+      # Needed to ensure unquantized weights get trained as part of the wrapper.
+      self._trainable_weights.append(weight)

     self._quantize_activations = []
     for activation, quantizer in \

@@ -214,9 +215,7 @@ def trainable(self, value):

   @property
   def trainable_weights(self):
-    # Change the order to keep the weight order after applying QAT.
-    return self._dedup_weights(
-        self._trainable_weights + self.layer.trainable_weights)
+    return self.layer.trainable_weights + self._trainable_weights

   @property
   def non_trainable_weights(self):
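
The quantize_wrapper.py hunks above change both where the wrapped layer's weights are registered (a per-weight append inside build rather than a bulk extend) and how the trainable_weights property orders them. Below is a simplified sketch of the two orderings as they appear in this diff, using hypothetical names inner_weights (standing in for self.layer.trainable_weights) and wrapper_weights (standing in for self._trainable_weights).

def dedup(weights):
  # Order-preserving de-duplication by identity (see the earlier sketch).
  output, seen = [], set()
  for w in weights:
    if id(w) not in seen:
      seen.add(id(w))
      output.append(w)
  return output

def trainable_weights_removed_variant(wrapper_weights, inner_weights):
  # Variant on the removed (-) lines: wrapper-held variables first, then the
  # inner layer's weights, with duplicates dropped while keeping first-seen order.
  return dedup(wrapper_weights + inner_weights)

def trainable_weights_added_variant(wrapper_weights, inner_weights):
  # Variant on the added (+) line: the inner layer's weights lead, followed by
  # the wrapper's own variables (quantizer state and the appended raw weights).
  return inner_weights + wrapper_weights

Only the removed variant de-duplicates, which matters when the two lists share variable objects; the ordering a training loop sees therefore depends on which variant is in effect.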
