
Commit 3b55bba

Xhark authored and tensorflower-gardener committed
Add QuantizeWrapperV2, which preserves weight order, and make it the default for quantize_apply.
PiperOrigin-RevId: 396525480
1 parent f676a3b · commit 3b55bba
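For context, quantize_apply wraps each annotated layer, and with the original QuantizeWrapper the quantizer variables could end up ahead of the wrapped layer's own kernel and bias in trainable_weights, so the order no longer matched the float model and positional weight transfer broke. A minimal sketch of the new default behavior (the model below is illustrative, not from this commit):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# An illustrative float model.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(8,)),
])

# quantize_model -> quantize_apply now uses QuantizeWrapperV2, so the
# layer's kernel/bias stay first in trainable_weights and positional
# weight transfer from the float model keeps working.
qat_model = tfmot.quantization.keras.quantize_model(model)
print([w.name for w in qat_model.trainable_weights])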

File tree

5 files changed: +84 −2 lines changed

tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@
 # handle custom Keras layers.
 from tensorflow_model_optimization.python.core.quantization.keras.quantize_config import QuantizeConfig
 from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapper
-
+from tensorflow_model_optimization.python.core.quantization.keras.quantize_wrapper import QuantizeWrapperV2
 # Deserialize quantized model for Keras h5 format.
 from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_scope
 
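With the export above, the new wrapper is reachable from the public API alongside the original; a quick sketch, assuming the usual tfmot alias:

import tensorflow_model_optimization as tfmot

# Both wrapper classes are exposed by the public quantization API.
QuantizeWrapper = tfmot.quantization.keras.QuantizeWrapper
QuantizeWrapperV2 = tfmot.quantization.keras.QuantizeWrapperV2

# V2 subclasses V1 (see quantize_wrapper.py below), so isinstance checks
# written against QuantizeWrapper keep matching V2-wrapped layers.
assert issubclass(QuantizeWrapperV2, QuantizeWrapper)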

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 1 addition & 0 deletions

@@ -269,6 +269,7 @@ py_strict_test(
         ":quantize_config",
         ":quantize_layer",
         ":quantize_wrapper",
+        ":quantizers",
         # numpy dep1,
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/keras:test_utils",

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 3 additions & 1 deletion

@@ -67,6 +67,7 @@ def quantize_scope(*args):
           quantize_aware_activation.QuantizeAwareActivation,
       'NoOpActivation': quantize_aware_activation.NoOpActivation,
       'QuantizeWrapper': quantize_wrapper.QuantizeWrapper,
+      'QuantizeWrapperV2': quantize_wrapper.QuantizeWrapperV2,
       'QuantizeLayer': quantize_layer.QuantizeLayer,
       'OutputOnlyConfig': quantize_config_mod.OutputOnlyConfig,
   }
@@ -401,7 +402,8 @@ def _quantize(layer):  # pylint: disable=missing-docstring
     # `QuantizeAnnotate`. This should generally be fine, but occasionally
     # `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
     # TODO(pulkitb): Ensure this does not affect model cloning.
-    return quantize_wrapper.QuantizeWrapper(layer, quantize_config)
+    return quantize_wrapper.QuantizeWrapperV2(
+        layer, quantize_config)
 
   # 1. Create a copy of the model with the same weights. This ensures
   # modifications don't affect the original model, or its weights.
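Registering 'QuantizeWrapperV2' in quantize_scope means models saved with the new default deserialize cleanly; a short sketch (the file path is illustrative):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tfmot.quantization.keras.quantize_model(tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(8,)),
]))
model.save('qat_model.h5')  # illustrative path

# quantize_scope now supplies QuantizeWrapperV2 as a custom object,
# so h5 models produced by the new default quantize_apply load back.
with tfmot.quantization.keras.quantize_scope():
  loaded = tf.keras.models.load_model('qat_model.h5')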

tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py

Lines changed: 63 additions & 0 deletions

@@ -27,6 +27,7 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_config as quantize_config_mod
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper as quantize_wrapper_mod
+from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 
 quantize_annotate_layer = quantize.quantize_annotate_layer
@@ -513,6 +514,68 @@ def testQuantizeApply_RunsWhenNestedModelNotAnnotated(self):
 
     quantize_apply(annotated_model)
 
+  class CustomConvLayer(tf.keras.layers.Layer):
+
+    def __init__(self, name=None, **kwargs):
+      super().__init__(name=name, **kwargs)
+      self.conv1 = tf.keras.layers.Conv2D(2, 2)
+
+    def build(self, input_shape):
+      self.conv1.build(input_shape)
+
+    def call(self, inputs):
+      return self.conv1(inputs)
+
+    def get_config(self):
+      return {'name': self.name}
+
+  class CustomConvQuantizeConfig(quantize_config_mod.QuantizeConfig):
+
+    def get_weights_and_quantizers(self, layer):
+      return [(layer.conv1.kernel, quantizers.LastValueQuantizer(
+          num_bits=8, symmetric=True, narrow_range=False, per_axis=False)),]
+
+    def get_activations_and_quantizers(self, layer):
+      return []
+
+    def set_quantize_weights(self, layer, quantize_weights):
+      # layer.conv1._kernel_bak = layer.conv1.kernel
+      layer.conv1.kernel = quantize_weights[0]
+
+    def set_quantize_activations(self, layer, quantize_activations):
+      pass
+
+    def get_output_quantizers(self, layer):
+      return []
+
+    def get_config(self):
+      return {}
+
+  def testQuantizeApply_KeepTrainableWeightOrder(self):
+    layer = self.CustomConvLayer(input_shape=(28, 28, 3))
+    model = keras.Sequential([layer])
+
+    def apply_quantization_to_dense(layer):
+      if isinstance(layer, self.CustomConvLayer):
+        return quantize_annotate_layer(
+            layer, quantize_config=self.CustomConvQuantizeConfig())
+      return layer
+
+    annotated_model = tf.keras.models.clone_model(
+        model,
+        clone_function=apply_quantization_to_dense,
+    )
+
+    with quantize.quantize_scope({
+        'CustomConvQuantizeConfig': self.CustomConvQuantizeConfig,
+        'CustomConvLayer': self.CustomConvLayer
+    }):
+      quant_aware_model = quantize_apply(annotated_model)
+
+    self._assert_weights_different_objects(
+        model.trainable_weights, quant_aware_model.trainable_weights)
+    self._assert_weights_equal_value(
+        model.trainable_weights, quant_aware_model.trainable_weights)
 
 if __name__ == '__main__':
   tf.test.main()
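The test leans on two helpers not shown in this hunk, _assert_weights_different_objects and _assert_weights_equal_value. A plausible sketch of what they check, assuming identity and value comparison respectively (the actual implementations live elsewhere in the test class):

import numpy as np

def _assert_weights_different_objects(self, weights_a, weights_b):
  # quantize_apply clones the model, so each variable must be a new object.
  self.assertEqual(len(weights_a), len(weights_b))
  for a, b in zip(weights_a, weights_b):
    self.assertIsNot(a, b)

def _assert_weights_equal_value(self, weights_a, weights_b):
  # Compared position by position, the values must match, which also
  # pins down that the weight ordering survived quantize_apply.
  for a, b in zip(weights_a, weights_b):
    np.testing.assert_array_equal(a.numpy(), b.numpy())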

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 16 additions & 0 deletions

@@ -228,3 +228,19 @@ def updates(self):
   @property
   def losses(self):
     return self.layer.losses + self._losses
+
+
+# TODO(b/199809494): Update guide document to use QuantizeWrapperV2.
+# Do not override this class method to quantize wrapper directly.
+# It breaks existing h5 models that uses QuantizeWrapper class.
+class QuantizeWrapperV2(QuantizeWrapper):
+
+  def build(self, input_shape):
+    self._trainable_weights.extend(self.layer.trainable_weights)
+    super(QuantizeWrapperV2, self).build(input_shape)
+
+  @property
+  def trainable_weights(self):
+    # Change the order to keep the weight order after applying QAT.
+    return self._dedup_weights(
+        self._trainable_weights + self.layer.trainable_weights)
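The ordering trick is small enough to sketch outside Keras: build() seeds the wrapper's own weight list with the wrapped layer's variables before the quantizers add theirs, and the overridden property dedups with the seeded list first, so the original positions win. A self-contained sketch (plain Python, names illustrative):

def dedup_preserving_order(weights):
  # First occurrence wins, by object identity, mirroring the intent of
  # Keras's _dedup_weights helper.
  seen, result = set(), []
  for w in weights:
    if id(w) not in seen:
      seen.add(id(w))
      result.append(w)
  return result

layer_weights = ['kernel', 'bias']                # wrapped layer's weights
quantizer_weights = ['kernel_min', 'kernel_max']  # created by the wrapper
own = layer_weights + quantizer_weights           # _trainable_weights after build()

# QuantizeWrapperV2.trainable_weights: seeded list first, layer's second.
print(dedup_preserving_order(own + layer_weights))
# -> ['kernel', 'bias', 'kernel_min', 'kernel_max']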
