Skip to content

Commit 4e206a5

Browse files
Xhark authored and tensorflower-gardener committed
Add fix_input_output_range method that fixes the input and output range after applying quantization.
PiperOrigin-RevId: 427484618
1 parent 6e8584f commit 4e206a5

File tree

9 files changed

+404
-15
lines changed

9 files changed

+404
-15
lines changed

tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_model
2727

2828
# quantize some layers with default or custom quantization implementation.
29+
from tensorflow_model_optimization.python.core.quantization.keras.quantize import fix_input_output_range
2930
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_annotate_layer
3031
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_annotate_model
3132
from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_apply

tensorflow_model_optimization/python/core/api/quantization/keras/quantizers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# quantize with custom quantization parameterization or implementation, or
1818
# handle custom Keras layers.
1919
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import AllValuesQuantizer
20+
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import FixedQuantizer
2021
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import LastValueQuantizer
2122
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import MovingAverageQuantizer
2223
from tensorflow_model_optimization.python.core.quantization.keras.quantizers import Quantizer

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ py_strict_library(
9090
visibility = ["//visibility:public"],
9191
deps = [
9292
# six dep1,
93+
# tensorflow dep1,
94+
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
9395
],
9496
)
9597

@@ -270,6 +272,7 @@ py_strict_test(
270272
":quantize_layer",
271273
":quantize_wrapper",
272274
":quantizers",
275+
# absl/testing:parameterized dep1,
273276
# numpy dep1,
274277
# tensorflow dep1,
275278
"//tensorflow_model_optimization/python/core/keras:test_utils",

tensorflow_model_optimization/python/core/quantization/keras/quant_ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,17 @@
2525
from tensorflow_model_optimization.python.core.keras import compat as tf_compat
2626

2727

28-
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
28+
def FixedQuantize(
29+
inputs, init_min=-6.0, init_max=6.0, scope=None, narrow_range=False):
2930
"""Adds a fake quantize layer with fixed quantization interval.
3031
3132
Args:
3233
inputs: a tensor containing values to be quantized.
3334
init_min: the lower end of quantization interval.
3435
init_max: the upper end of quantization interval.
3536
scope: Optional scope for name_scope.
37+
narrow_range: Whether to use the narrow quantization range
38+
[1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
3639
Returns:
3740
a tensor containing quantized values.
3841
"""
@@ -41,7 +44,7 @@ def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
4144

4245
with tf.name_scope(scope):
4346
return tf.quantization.fake_quant_with_min_max_args(
44-
inputs, min=init_min, max=init_max)
47+
inputs, min=init_min, max=init_max, narrow_range=narrow_range)
4548

4649

4750
def AllValuesQuantize(inputs,

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def quantize_scope(*args):
7171
'QuantizeWrapperV2': quantize_wrapper.QuantizeWrapperV2,
7272
'QuantizeLayer': quantize_layer.QuantizeLayer,
7373
'OutputOnlyConfig': quantize_config_mod.OutputOnlyConfig,
74+
'FixedQuantizeConfig': quantize_config_mod.FixedQuantizeConfig,
7475
}
7576
quantization_objects.update(default_8bit_quantize_registry._types_dict()) # pylint: disable=protected-access
7677
quantization_objects.update(default_n_bit_quantize_registry._types_dict()) # pylint: disable=protected-access
@@ -472,3 +473,169 @@ def _quantize(layer): # pylint: disable=missing-docstring
472473

473474
return keras.models.clone_model(
474475
transformed_model, input_tensors=None, clone_function=_quantize)
476+
477+
478+
def _unwrap_first_input_name(inbound_nodes):
479+
"""Unwrap inbound_nodes three times to get first input name.
480+
481+
Args:
482+
inbound_nodes: A str config that indicates input node. This method assumed
483+
the inbound_nodes looks like `[[['input', 0, 0, {}]]]`.
484+
485+
Returns:
486+
Returns a str name for the first inbound node.
487+
"""
488+
current = inbound_nodes
489+
490+
for _ in range(3):
491+
if not current:
492+
return None
493+
if not isinstance(current, list):
494+
return None
495+
current = current[0]
496+
497+
if isinstance(current, str):
498+
return current
499+
500+
return None
501+
502+
503+
def _wrap_fixed_range(
    quantize_config, num_bits, init_min, init_max, narrow_range):
  """Wraps a serialized quantize_config in a serialized FixedQuantizeConfig.

  Args:
    quantize_config: Serialized quantize config to wrap.
    num_bits: Number of bits for quantization.
    init_min: The lower end of the fixed quantization interval.
    init_max: The upper end of the fixed quantization interval.
    narrow_range: Whether to use the narrow quantization range.

  Returns:
    A serialized `FixedQuantizeConfig` that wraps `quantize_config`.
  """
  wrapped = quantize_config_mod.FixedQuantizeConfig.from_config({
      'config': quantize_config,
      'num_bits': num_bits,
      'init_min': init_min,
      'init_max': init_max,
      'narrow_range': narrow_range,
  })
  return tf.keras.utils.serialize_keras_object(wrapped)
512+
513+
514+
def _is_serialized_node_data(nested):
515+
# Node data can be of form `[layer_name, node_id, tensor_id]` or
516+
# `[layer_name, node_id, tensor_id, kwargs]`.
517+
if (isinstance(nested, list) and (len(nested) in [3, 4]) and
518+
isinstance(nested[0], str)):
519+
return True
520+
return False
521+
522+
523+
def _nested_to_flatten_node_data_list(nested):
  """Flattens arbitrarily nested node data into a flat list of node data.

  Recursively walks lists and dict values, collecting every serialized node
  data entry (see `_is_serialized_node_data`).

  Args:
    nested: Nested lists/dicts containing serialized node data entries.

  Returns:
    A flat list of serialized node data entries.

  Raises:
    ValueError: If `nested` contains an unsupported element.
  """
  if _is_serialized_node_data(nested):
    return [nested]

  if isinstance(nested, list):
    flattened = []
    for item in nested:
      flattened.extend(_nested_to_flatten_node_data_list(item))
    return flattened

  if isinstance(nested, dict):
    flattened = []
    for value in nested.values():
      flattened.extend(_nested_to_flatten_node_data_list(value))
    return flattened

  raise ValueError('{} is not a supported nested node data.'.format(nested))
535+
536+
537+
def fix_input_output_range(
    model,
    num_bits=8,
    input_min=0.0,
    input_max=1.0,
    output_min=0.0,
    output_max=1.0,
    narrow_range=False):
  """Fix the input and output ranges.

  Example:

  ```python
  model = keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      quantize_annotate_layer(layers.Dense(2, activation='sigmoid'))
  ])
  with quantize.quantize_scope():
    model = quantize_annotate_model(model)
    model = quantize_apply(model)
    model = fix_input_output_range(model, num_bits=4,
                                   input_min=0, input_max=15,
                                   output_min=0, output_max=15,
                                   narrow_range=False)
  ```

  In certain cases, a desired input/output range is known and should not be
  altered during training. To set these values, use the arguments as follows:

  Args:
    model: A `tf.keras` Sequential or Functional model which has been
      quantized.
    num_bits: Number of bits for quantization.
    input_min: The lower end of quantization interval for the input.
    input_max: The upper end of quantization interval for the input.
    output_min: The lower end of quantization interval for the output.
    output_max: The upper end of quantization interval for the output.
    narrow_range: In case of 8 bits, narrow_range nudges the quantized range
      to be [-127, 127] instead of [-128, 127]. This ensures symmetric range
      has 0 as the centre.

  Returns:
    Returns a new `tf.keras` model with fixed input range set to (input_min,
    input_max) and fixed output range set to (output_min, output_max).

  Raises:
    ValueError: If a Sequential `model` has not been quantized (no
      `QuantizeLayer` in the expected position).
  """
  config = model.get_config()
  fixed_input_quantizer = quantizers.FixedQuantizer(
      num_bits=num_bits,
      init_min=input_min,
      init_max=input_max,
      narrow_range=narrow_range)
  serialized_fixed_input_quantizer = tf.keras.utils.serialize_keras_object(
      fixed_input_quantizer)

  if _is_functional_model(model):
    # Replace the quantizer of every layer directly fed by an input layer
    # with the fixed input-range quantizer.
    input_layer_list = _nested_to_flatten_node_data_list(config['input_layers'])
    for layer_config in config['layers']:
      input_name = _unwrap_first_input_name(layer_config['inbound_nodes'])
      if input_name is None:
        continue

      for input_layer in input_layer_list:
        if input_name == input_layer[0]:
          layer_config['config']['quantizer'] = serialized_fixed_input_quantizer
          break

    # Wrap the quantize_config of every output layer so its outputs are
    # quantized with the fixed output range.
    output_layer_list = _nested_to_flatten_node_data_list(
        config['output_layers'])
    for layer_config in config['layers']:
      for output_layer in output_layer_list:
        if layer_config['config']['name'] == output_layer[0]:
          if 'quantize_config' in layer_config['config']:
            layer_config['config']['quantize_config'] = (
                _wrap_fixed_range(
                    layer_config['config']['quantize_config'],
                    num_bits=num_bits,
                    init_min=output_min,
                    init_max=output_max,
                    narrow_range=narrow_range))
          break

    model = keras.Model.from_config(config)
  else:
    # Sequential path: the check below expects layers[1] to be the
    # QuantizeLayer inserted by quantize_apply (layers[0] presumably the
    # InputLayer — confirm against quantize_apply). Guard with < 2 (not < 1)
    # so a too-short config raises the intended ValueError rather than an
    # IndexError on the layers[1] access.
    if (len(config['layers']) < 2 or
        config['layers'][1]['class_name'] != 'QuantizeLayer'):
      raise ValueError('`model` should be already quantized.')
    config['layers'][1]['config'][
        'quantizer'] = serialized_fixed_input_quantizer
    # The last layer produces the model output: fix its output range.
    if 'quantize_config' in config['layers'][-1]['config']:
      config['layers'][-1]['config']['quantize_config'] = (
          _wrap_fixed_range(
              config['layers'][-1]['config']['quantize_config'],
              num_bits=num_bits,
              init_min=output_min,
              init_max=output_max,
              narrow_range=narrow_range))

    model = keras.Sequential.from_config(config)

  return model
636+
637+
638+
def _is_functional_model(model):
  """Returns True for a Functional (graph-network) Keras model."""
  if not isinstance(model, keras.Model):
    return False
  if isinstance(model, keras.Sequential):
    return False
  return model._is_graph_network  # pylint: disable=protected-access

tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import abc
1818
import six
1919

20+
import tensorflow as tf
21+
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
22+
2023

2124
@six.add_metaclass(abc.ABCMeta)
2225
class QuantizeConfig(object):
@@ -215,3 +218,54 @@ def get_config(self):
215218
@classmethod
216219
def from_config(cls, config):
217220
return cls(**config)
221+
222+
223+
class FixedQuantizeConfig(QuantizeConfig):
  """QuantizeConfig that quantizes output with fixed range."""

  def __init__(self, config, num_bits, init_min, init_max, narrow_range):
    self.config = config
    self.num_bits = num_bits
    self.init_min = init_min
    self.init_max = init_max
    self.narrow_range = narrow_range
    # One fixed-interval quantizer reused for every activation and output.
    self.fixed_quantizer = quantizers.FixedQuantizer(
        num_bits=num_bits,
        init_min=init_min,
        init_max=init_max,
        narrow_range=narrow_range)

  def get_weights_and_quantizers(self, layer):
    # Weight quantization is delegated untouched to the wrapped config.
    return self.config.get_weights_and_quantizers(layer)

  def set_quantize_weights(self, layer, quantize_weights):
    return self.config.set_quantize_weights(layer, quantize_weights)

  def get_activations_and_quantizers(self, layer):
    # Keep the wrapped config's activations but force the fixed quantizer.
    pairs = self.config.get_activations_and_quantizers(layer)
    return [(activation, self.fixed_quantizer) for activation, _ in pairs]

  def set_quantize_activations(self, layer, quantize_activations):
    return self.config.set_quantize_activations(layer, quantize_activations)

  def get_output_quantizers(self, layer):
    # Replace each output quantizer with the fixed-range quantizer,
    # preserving the count reported by the wrapped config.
    count = len(self.config.get_output_quantizers(layer))
    return [self.fixed_quantizer] * count

  def get_config(self):
    return {
        'config': tf.keras.utils.serialize_keras_object(self.config),
        'num_bits': self.num_bits,
        'init_min': self.init_min,
        'init_max': self.init_max,
        'narrow_range': self.narrow_range}

  @classmethod
  def from_config(cls, config):
    config['config'] = tf.keras.utils.deserialize_keras_object(config['config'])
    return cls(**config)

0 commit comments

Comments
 (0)