
Commit f148e43

nutsiepully authored and tensorflower-gardener committed
Use per-axis quantization for Conv2D/DepthwiseConv2D.
PiperOrigin-RevId: 285254892
1 parent 089fadb commit f148e43
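
For context: per-axis (per-channel) quantization gives every output channel of a convolution kernel its own quantization range, instead of one range shared by the whole weight tensor, which is what this change enables for Conv2D/DepthwiseConv2D weights. A minimal illustrative sketch of the difference using TensorFlow's fake-quant ops (not part of this commit; the tensor shape and variable names are made up):

    import tensorflow as tf

    # Hypothetical Conv2D kernel: 3x3 receptive field, 8 input and 16 output channels.
    kernel = tf.random.uniform([3, 3, 8, 16], minval=-1.0, maxval=1.0)

    # Per-tensor quantization: a single (min, max) range for the entire kernel.
    per_tensor = tf.quantization.fake_quant_with_min_max_vars(
        kernel, tf.reduce_min(kernel), tf.reduce_max(kernel), num_bits=8)

    # Per-axis quantization: one (min, max) range per output channel (last axis),
    # so each of the 16 filters is quantized with its own scale.
    per_axis = tf.quantization.fake_quant_with_min_max_vars_per_channel(
        kernel,
        tf.reduce_min(kernel, axis=[0, 1, 2]),
        tf.reduce_max(kernel, axis=[0, 1, 2]),
        num_bits=8)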

File tree

2 files changed: +18 −5 lines


tensorflow_model_optimization/python/core/quantization/keras/tflite/tflite_quantize_registry.py

Lines changed: 17 additions & 4 deletions
@@ -27,6 +27,7 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm
+from tensorflow_model_optimization.python.core.quantization.keras.tflite import tflite_quantizers

 QuantizeProvider = quantize_provider.QuantizeProvider

@@ -84,7 +85,6 @@ class TFLiteQuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):

       # Convolution Layers
       _QuantizeInfo(layers.convolutional.Conv1D, ['kernel'], ['activation']),
-      _QuantizeInfo(layers.convolutional.Conv2D, ['kernel'], ['activation']),
       _QuantizeInfo(layers.convolutional.Conv3D, ['kernel'], ['activation']),
       # TODO(pulkitb): Verify Transpose layers.
       _QuantizeInfo(layers.convolutional.Conv2DTranspose,
@@ -94,8 +94,6 @@ class TFLiteQuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
       _no_quantize(layers.convolutional.Cropping1D),
       _no_quantize(layers.convolutional.Cropping2D),
       _no_quantize(layers.convolutional.Cropping3D),
-      _QuantizeInfo(layers.convolutional.DepthwiseConv2D,
-                    ['depthwise_kernel'], ['activation']),
       _no_quantize(layers.convolutional.UpSampling1D),
       _no_quantize(layers.convolutional.UpSampling2D),
       _no_quantize(layers.convolutional.UpSampling3D),
@@ -172,6 +170,10 @@ def __init__(self):
     # Hack for `Activation` layer. That is the only layer with a separate
     # QuantizeProvider.
     self._layer_quantize_map[layers.Activation] = ActivationQuantizeProvider()
+    self._layer_quantize_map[layers.Conv2D] = ConvQuantizeProvider(
+        ['kernel'], ['activation'], False)
+    self._layer_quantize_map[layers.DepthwiseConv2D] = ConvQuantizeProvider(
+        ['depthwise_kernel'], ['activation'], False)

   def _is_supported_layer(self, layer):
     return layer.__class__ in self._layer_quantize_map
@@ -466,9 +468,20 @@ def get_config(self):
     return {}


+class ConvQuantizeProvider(TFLiteQuantizeProvider):
+  """QuantizeProvider for Conv2D/DepthwiseConv2D layers."""
+
+  def __init__(self, weight_attrs, activation_attrs, quantize_output):
+    super(ConvQuantizeProvider, self).__init__(
+        weight_attrs, activation_attrs, quantize_output)
+
+    self.weight_quantizer = tflite_quantizers.ConvWeightsQuantizer()
+
+
 def _types_dict():
   return {
       'TFLiteQuantizeProvider': TFLiteQuantizeProvider,
       'TFLiteQuantizeProviderRNN': TFLiteQuantizeProviderRNN,
-      'ActivationQuantizeProvider': ActivationQuantizeProvider
+      'ActivationQuantizeProvider': ActivationQuantizeProvider,
+      'ConvQuantizeProvider': ConvQuantizeProvider
   }
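
The net effect of the registry change above: Conv2D and DepthwiseConv2D move out of the table-driven _LAYER_QUANTIZE_INFO defaults and are wired to an explicit ConvQuantizeProvider whose weight quantizer is the per-axis ConvWeightsQuantizer. A small sketch of what that provider carries, constructed directly with the same arguments the registry's __init__ passes (illustrative only; assumes the import paths in this file resolve as written):

    from tensorflow_model_optimization.python.core.quantization.keras.tflite import (
        tflite_quantize_registry, tflite_quantizers)

    # Same arguments the registry uses for Conv2D in __init__ above:
    # weight attribute 'kernel', activation attribute 'activation',
    # and quantize_output=False.
    provider = tflite_quantize_registry.ConvQuantizeProvider(
        ['kernel'], ['activation'], False)

    # The weight quantizer is the per-axis one, replacing the per-tensor
    # quantizer inherited from TFLiteQuantizeProvider.
    assert isinstance(provider.weight_quantizer,
                      tflite_quantizers.ConvWeightsQuantizer)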

tensorflow_model_optimization/python/core/quantization/keras/tflite/tflite_quantizers.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def build(self, tensor_shape, name, layer):
         name + '_min',
         shape=(tensor_shape[-1],),
         initializer=initializers.Constant(-6.0),
-        trainable=False,)
+        trainable=False)
     max_weight = layer.add_weight(
         name + '_max',
         shape=(tensor_shape[-1],),
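
For reference, the shape=(tensor_shape[-1],) in build above is what makes this quantizer per-axis: for a Conv2D kernel the last dimension is the number of output channels, so one (min, max) pair is created per channel. A rough sketch of how per-channel 8-bit step sizes would follow from such ranges (numbers are made up; the formula is the generic asymmetric-quantization scale, not necessarily the exact computation used downstream):

    import numpy as np

    # Hypothetical per-output-channel ranges, shaped like the (tensor_shape[-1],)
    # min/max variables built above (3 output channels in this toy example).
    min_weight = np.array([-6.0, -0.5, -2.0])
    max_weight = np.array([ 6.0,  0.5,  2.0])

    # One 8-bit step size (scale) per output channel.
    scales = (max_weight - min_weight) / (2**8 - 1)
    print(scales)  # ~[0.0471, 0.0039, 0.0157]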
