Skip to content

Commit baf1bb8

Browse files
nutsiepully authored and
tensorflower-gardener committed
Implement AllValuesQuantizer
AllValuesQuantizer calculates the range based on the largest and smallest values seen by the Tensor. PiperOrigin-RevId: 320088710
1 parent ec636c9 commit baf1bb8

File tree

4 files changed

+176
-1
lines changed

4 files changed

+176
-1
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quant_ops.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,75 @@ def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
4444
inputs, min=init_min, max=init_max)
4545

4646

def AllValuesQuantize(inputs,
                      min_var,
                      max_var,
                      name_prefix='AllValuesQuantize',
                      is_training=True,
                      num_bits=8,
                      narrow_range=False,
                      symmetric=False):
  """Adds a layer that collects quantization ranges as min/max of tensor values.

  AllValuesQuantize creates variables called 'min' and 'max',
  representing the interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    min_var: Variable which stores the min value of tensor.
    max_var: Variable which stores the max value of tensor.
    name_prefix: name_prefix for created nodes.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
    symmetric: If true, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.

  Returns:
    a tensor containing quantized values.
  """
  with tf.name_scope(name_prefix):
    if not is_training:
      # Eval/inference: quantize with the ranges already stored in the
      # variables; do not update them.
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=False,
          num_bits=num_bits,
          narrow_range=narrow_range)

    batch_min = tf.math.reduce_min(inputs, name='BatchMin')
    batch_max = tf.math.reduce_max(inputs, name='BatchMax')

    if symmetric:
      if narrow_range:
        min_max_ratio = -1
      else:
        # In two's complement notation, the negative range is slightly larger
        # than the positive range.
        min_max_ratio = -((1 << num_bits) - 2) / (1 << num_bits)

      # TFLite requires that 0.0 is always in the [min; max] range. Because
      # batch_min <= batch_max, it follows that range_min <= 0 <= range_max.
      batch_min = tf.math.minimum(batch_min, batch_max / min_max_ratio)
      batch_max = tf.math.maximum(batch_max, batch_min * min_max_ratio)

    # TFLite requires that 0.0 is always in the [min; max] range.
    # "All values": the stored range only ever widens — new extremes are
    # folded into the existing min/max rather than replacing them.
    range_min = tf.math.minimum(tf.math.minimum(min_var, batch_min), 0.0)
    range_max = tf.math.maximum(tf.math.maximum(max_var, batch_max), 0.0)

    assign_min = tf_compat.assign(min_var, range_min, name='AssignMinAllValue')
    assign_max = tf_compat.assign(max_var, range_max, name='AssignMaxAllValue')

    # Quantize using the assign ops so the variable updates are part of the
    # graph's data dependencies.
    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=False,
        num_bits=num_bits,
        narrow_range=narrow_range)
47116
def LastValueQuantize(inputs,
48117
min_var,
49118
max_var,

tensorflow_model_optimization/python/core/quantization/keras/quant_ops_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,31 @@
3232
@keras_parameterized.run_all_keras_modes
3333
class QuantOpsTest(tf.test.TestCase, parameterized.TestCase):
3434

35+
def testAllValuesQuantiize_TrainingAssign(self):
36+
min_value, max_value = self._GetMinMaxValues(
37+
quant_ops.AllValuesQuantize,
38+
[tf.constant([-5.0, 1.0]), tf.constant([-1.0, 5.0])])
39+
40+
self.assertEqual(min_value, -5.0)
41+
self.assertEqual(max_value, 5.0)
42+
43+
def testAllValuesQuantiize_SymmetricTrainingAssign(self):
44+
min_value, max_value = self._GetMinMaxValues(
45+
quant_ops.AllValuesQuantize,
46+
[tf.constant([-_SYMMETRIC_RANGE_RATIO, _SYMMETRIC_RANGE_RATIO])],
47+
symmetric=True,
48+
narrow_range=False)
49+
self.assertEqual(min_value, -1.0)
50+
self.assertEqual(max_value, _SYMMETRIC_RANGE_RATIO)
51+
52+
def testAllValuesQuantiize_SymmetricNarrowRangeTrainingAssign(self):
53+
min_value, max_value = self._GetMinMaxValues(
54+
quant_ops.AllValuesQuantize, [tf.constant([-1, 0.5])],
55+
symmetric=True,
56+
narrow_range=True)
57+
self.assertEqual(min_value, -1.0)
58+
self.assertEqual(max_value, 1)
59+
3560
def testLastValueQuantizeTrainingAssign(self):
3661
min_value, max_value = self._GetMinMaxValues(quant_ops.LastValueQuantize,
3762
[tf.constant([-1.0, 1.0])])

tensorflow_model_optimization/python/core/quantization/keras/quantizers.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,87 @@ def __ne__(self, other):
288288
return not self.__eq__(other)
289289

290290

291+
class AllValuesQuantizer(_QuantizeHelper, Quantizer):
  """Quantize tensor based on min/max of tensor values across all batches."""

  def __init__(self, num_bits, per_axis, symmetric, narrow_range):
    """Construct an AllValuesQuantizer.

    This is an experimental API not subject to backward compatibility.

    Args:
      num_bits: Number of bits for quantization
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis. NOTE(review): this flag is stored and serialized in
        `get_config`, but `__call__` does not forward it to
        `quant_ops.AllValuesQuantize` — per-axis quantization appears to be
        unimplemented here; confirm intent.
      symmetric: If true, use symmetric quantization limits instead of training
        the minimum and maximum of each quantization range separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized range
        to be [-127, 127] instead of [-128, 127]. This ensures symmetric
        range has 0 as the centre.
    """
    self.num_bits = num_bits
    self.per_axis = per_axis
    self.symmetric = symmetric
    self.narrow_range = narrow_range

  def build(self, tensor_shape, name, layer):
    """Creates the non-trainable min/max range variables on `layer`.

    Both variables start at 0.0, so the observed range always contains zero.

    Args:
      tensor_shape: Shape of the tensor to be quantized (unused; ranges are
        scalar).
      name: Prefix for the created weight names.
      layer: Keras layer the weights are added to.

    Returns:
      Dict of weights, consumed by `__call__` via its `weights` argument.
    """
    min_weight = layer.add_weight(
        name + '_min',
        initializer=keras.initializers.Constant(0.0),
        trainable=False)
    max_weight = layer.add_weight(
        name + '_max',
        initializer=keras.initializers.Constant(0.0),
        trainable=False)
    return {'min_var': min_weight, 'max_var': max_weight}

  def __call__(self, inputs, training, weights, **kwargs):
    """Quantize tensor.

    Args:
      inputs: Input tensor to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights the quantizer can use to quantize the
        tensor. This contains the weights created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.

    Returns:
      Quantized tensor.
    """
    return quant_ops.AllValuesQuantize(
        inputs,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range,
    )

  def get_config(self):
    """Returns the constructor arguments for serialization."""
    return {
        'num_bits': self.num_bits,
        'per_axis': self.per_axis,
        'symmetric': self.symmetric,
        'narrow_range': self.narrow_range
    }

  def __eq__(self, other):
    # Quantizers are equal iff they are the same type with identical config.
    if not isinstance(other, AllValuesQuantizer):
      return False

    return (self.num_bits == other.num_bits and
            self.per_axis == other.per_axis and
            self.symmetric == other.symmetric and
            self.narrow_range == other.narrow_range)

  def __ne__(self, other):
    return not self.__eq__(other)
291369
def _types_dict():
  """Maps serialized quantizer class names to their classes."""
  quantizer_classes = (
      AllValuesQuantizer,
      LastValueQuantizer,
      MovingAverageQuantizer,
  )
  # Each key equals the class's own name, so derive it from __name__.
  return {cls.__name__: cls for cls in quantizer_classes}

tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333

3434
@keras_parameterized.run_all_keras_modes
3535
@parameterized.parameters(
36-
quantizers.LastValueQuantizer, quantizers.MovingAverageQuantizer)
36+
quantizers.LastValueQuantizer,
37+
quantizers.MovingAverageQuantizer,
38+
quantizers.AllValuesQuantizer)
3739
class QuantizersTest(tf.test.TestCase, parameterized.TestCase):
3840

3941
def setUp(self):

0 commit comments

Comments
 (0)