Skip to content

Commit 0b4eb1c

Browse files
nutsiepully authored and tensorflower-gardener committed
Create Quantizers specific to Conv/DConv for new quant scheme.
These new quantizers construct min/max vectors based on the shape of the kernels in Conv/DConv, and support multiple scales (one per channel) as in the new quant scheme. They still need to be plugged into the layers, and special-case handling for DConv has not been introduced yet.

PiperOrigin-RevId: 278964140
1 parent 2adce8a commit 0b4eb1c

File tree

5 files changed

+149
-6
lines changed

5 files changed

+149
-6
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quant_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,21 +99,21 @@ def LastValueQuantize(inputs,
9999

100100
if per_channel:
101101
if input_dim >= 2:
102-
batch_min = math_ops.reduce_min(
102+
batch_min = math_ops.reduce_min_v1(
103103
inputs, reduction_indices=reduce_dims, name='BatchMin')
104104
else:
105105
batch_min = inputs
106106
else:
107-
batch_min = math_ops.reduce_min(inputs, name='BatchMin')
107+
batch_min = math_ops.reduce_min_v1(inputs, name='BatchMin')
108108

109109
if per_channel:
110110
if input_dim >= 2:
111-
batch_max = math_ops.reduce_max(
111+
batch_max = math_ops.reduce_max_v1(
112112
inputs, reduction_indices=reduce_dims, name='BatchMax')
113113
else:
114114
batch_max = inputs
115115
else:
116-
batch_max = math_ops.reduce_max(inputs, name='BatchMax')
116+
batch_max = math_ops.reduce_max_v1(inputs, name='BatchMax')
117117

118118
if symmetric:
119119
if narrow_range:

tensorflow_model_optimization/python/core/quantization/keras/quantizers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def __init__(self, num_bits, per_axis, symmetric, narrow_range):
100100
101101
Args:
102102
num_bits: Number of bits for quantization
103-
per_axis: Whether to apply per_axis quantization.
103+
per_axis: Whether to apply per_axis quantization. The last dimension is
104+
used as the axis.
104105
symmetric: If true, use symmetric quantization limits instead of training
105106
the minimum and maximum of each quantization range separately.
106107
narrow_range: In case of 8 bits, narrow_range nudges the quantized range
@@ -167,7 +168,8 @@ def __init__(self, num_bits, per_axis, symmetric, narrow_range):
167168
168169
Args:
169170
num_bits: Number of bits for quantization
170-
per_axis: Whether to apply per_axis quantization.
171+
per_axis: Whether to apply per_axis quantization. The last dimension is
172+
used as the axis.
171173
symmetric: If true, use symmetric quantization limits instead of training
172174
the minimum and maximum of each quantization range separately.
173175
narrow_range: In case of 8 bits, narrow_range nudges the quantized range

tensorflow_model_optimization/python/core/quantization/keras/tflite/BUILD

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,36 @@ package(default_visibility = [
44

55
licenses(["notice"]) # Apache 2.0
66

7+
# Library target for the TFLite-specific quantizers module.
py_library(
    name = "tflite_quantizers",
    srcs = ["tflite_quantizers.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        # tensorflow dep1,
        # python/keras tensorflow dep2,
        "//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
    ],
)
20+
21+
# Unit tests for the TFLite-specific quantizers module.
py_test(
    name = "tflite_quantizers_test",
    srcs = ["tflite_quantizers_test.py"],
    python_version = "PY3",
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":tflite_quantizers",
        # absl/testing:parameterized dep1,
        # tensorflow dep1,
        # python/keras tensorflow dep2,
    ],
)
36+
737
py_library(
838
name = "tflite_quantize_registry",
939
srcs = [
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Quantizers specific to TFLite.
16+
17+
Module: tfmot.quantization.keras.tflite
18+
"""
19+
20+
from tensorflow.python.keras import initializers
21+
22+
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
23+
24+
25+
class ConvWeightsQuantizer(quantizers.LastValueQuantizer):
  """Weights quantizer for Conv2D/DepthwiseConv2D layers targeting TFLite.

  Configures the parent `LastValueQuantizer` for 8-bit, per-axis, symmetric,
  narrow-range quantization, and creates one min/max entry per element of
  the last dimension of the quantized tensor.
  """

  def __init__(self):
    """Set up the parent LastValueQuantizer with the TFLite conv settings."""
    tflite_conv_params = {
        'num_bits': 8,
        'per_axis': True,
        'symmetric': True,
        'narrow_range': True,
    }
    super(ConvWeightsQuantizer, self).__init__(**tflite_conv_params)

  def build(self, tensor_shape, name, layer):
    """Create the min/max range variables for the quantized tensor.

    Args:
      tensor_shape: Shape of the tensor to quantize. Only its last entry is
        read; it determines the length of the min/max vectors.
      name: Prefix for the variable names; '_min' and '_max' are appended.
      layer: Object whose `add_weight` method is used to create the variables.

    Returns:
      A two-element list `[min_weight, max_weight]`.
    """
    # One range entry per element of the last axis.
    per_channel_shape = (tensor_shape[-1],)

    # Both variables are non-trainable and start at -6.0 / 6.0 respectively.
    min_weight = layer.add_weight(
        name + '_min',
        shape=per_channel_shape,
        initializer=initializers.Constant(-6.0),
        trainable=False)
    max_weight = layer.add_weight(
        name + '_max',
        shape=per_channel_shape,
        initializer=initializers.Constant(6.0),
        trainable=False)

    return [min_weight, max_weight]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for TFLite Quantizers."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from absl.testing import parameterized
22+
23+
from tensorflow.python import keras
24+
from tensorflow.python.platform import test
25+
26+
from tensorflow_model_optimization.python.core.quantization.keras.tflite import tflite_quantizers
27+
28+
ConvWeightsQuantizer = tflite_quantizers.ConvWeightsQuantizer
29+
30+
31+
class ConvWeightsQuantizerTest(test.TestCase, parameterized.TestCase):
  """Tests the range variables created by ConvWeightsQuantizer."""

  @parameterized.parameters(
      (keras.layers.Conv2D, {
          'filters': 5,
          'kernel_size': (2, 2)
      }),
      (keras.layers.DepthwiseConv2D, {
          'kernel_size': (2, 2),
          'depth_multiplier': 5,
      })
  )
  def testConstructsMinMaxVarsCorrectShape(self, layer_type, kwargs):
    """Min/max variables have one entry per element of the kernel's last axis.

    Args:
      layer_type: Keras layer class to instantiate.
      kwargs: Constructor arguments for `layer_type`.
    """
    quantizer = ConvWeightsQuantizer()

    model = keras.Sequential([
        layer_type(input_shape=(5, 2, 3), **kwargs)])
    layer = model.layers[0]
    kernel = layer.weights[0]

    min_var, max_var = quantizer.build(kernel.shape, 'kernel', layer)
    # TODO(pulkitb): Add value test to ensure per-axis quantization is
    # happening properly. Probably to quant_ops_test.py
    # The result is intentionally discarded: the call only verifies that the
    # quantizer accepts the variables produced by `build`.
    quantizer(kernel, 0, True, min_var=min_var, max_var=max_var)

    # Compare against an explicit rank-1 shape instead of relying on implicit
    # int-to-shape coercion.
    self.assertEqual([5], min_var.shape.as_list())
    self.assertEqual([5], max_var.shape.as_list())
60+
61+
if __name__ == '__main__':
62+
test.main()

0 commit comments

Comments
 (0)