
Commit b67c88f

Xhark authored and tensorflower-gardener committed
Add remove_input_range method that removes the input range after quantize_apply.
PiperOrigin-RevId: 456004543
1 parent 290793e commit b67c88f

File tree

4 files changed: +103 -2 lines

tensorflow_model_optimization/python/core/api/quantization/keras/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_annotate_layer
 from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_annotate_model
 from tensorflow_model_optimization.python.core.quantization.keras.quantize import quantize_apply
+from tensorflow_model_optimization.python.core.quantization.keras.quantize import remove_input_range

 # quantize with custom quantization parameterization or implementation, or
 # handle custom Keras layers.
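
With this export added, `remove_input_range` should be reachable from the public Keras quantization API next to `quantize_apply`. A minimal sketch of the expected import path, assuming the usual `tensorflow-model-optimization` package layout (the `tfmot` alias and attribute path are assumptions, not part of the commit):

```python
# Sketch of the public API surface this export feeds into.
import tensorflow_model_optimization as tfmot

quantize_apply = tfmot.quantization.keras.quantize_apply
remove_input_range = tfmot.quantization.keras.remove_input_range  # new in this commit
```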

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 55 additions & 0 deletions
@@ -644,3 +644,58 @@ def _is_functional_model(model):
   return (isinstance(model, keras.Model)
           and not isinstance(model, keras.Sequential)
           and model._is_graph_network)  # pylint: disable=protected-access
+
+
+def remove_input_range(model):
+  """Remove the input range from a quantized model.
+
+  Example:
+
+  ```python
+  model = keras.Sequential([
+      layers.Dense(10, activation='relu', input_shape=(100,)),
+      quantize_annotate_layer(layers.Dense(2, activation='sigmoid'))
+  ])
+  with quantize.quantize_scope():
+    model = quantize_annotate_model(model)
+    model = quantize_apply(model)
+    model = remove_input_range(model)
+  ```
+
+  In certain cases an input range is not required, for example when the
+  quantized model is only used internally.
+
+  Args:
+    model: A `tf.keras` Sequential or Functional model that has been quantized.
+
+  Returns:
+    A new `tf.keras` model with the input range removed.
+  """
+  config = model.get_config()
+  no_input_quantizer = quantizers.NoQuantizer()
+  serialized_input_quantizer = tf.keras.utils.serialize_keras_object(
+      no_input_quantizer)
+
+  if _is_functional_model(model):
+    input_layer_list = _nested_to_flatten_node_data_list(config['input_layers'])
+    for layer_config in config['layers']:
+      input_name = _unwrap_first_input_name(layer_config['inbound_nodes'])
+      if input_name is None:
+        continue
+
+      for input_layer in input_layer_list:
+        if input_name == input_layer[0]:
+          layer_config['config']['quantizer'] = serialized_input_quantizer
+          break
+
+    model = keras.Model.from_config(config)
+  else:
+    if (len(config['layers']) < 2 or
+        config['layers'][1]['class_name'] != 'QuantizeLayer'):
+      raise ValueError('`model` should already be quantized.')
+    config['layers'][1]['config'][
+        'quantizer'] = serialized_input_quantizer
+
+    model = keras.Sequential.from_config(config)
+
+  return model
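
As a rough illustration, the rewrite can be checked on a Sequential model by inspecting the `QuantizeLayer` entry in the rebuilt config. The snippet below is a sketch that assumes the layer layout the `else` branch expects (an `InputLayer` at index 0 and the `QuantizeLayer` at index 1); it is not code from the commit:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_model_optimization.python.core.quantization.keras import quantize

# Quantize a small Sequential model end to end, then drop its input range.
base = keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    layers.Dense(2, activation='sigmoid'),
])
annotated = quantize.quantize_annotate_model(base)
quantized = quantize.quantize_apply(annotated)
stripped = quantize.remove_input_range(quantized)

# After the rewrite, the QuantizeLayer's serialized quantizer should be the
# NoQuantizer inserted above rather than the default input quantizer.
quantize_layer_config = stripped.get_config()['layers'][1]['config']
print(quantize_layer_config['quantizer'])
```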

tensorflow_model_optimization/python/core/quantization/keras/quantizers.py

Lines changed: 42 additions & 0 deletions
@@ -437,10 +437,52 @@ def __ne__(self, other):
     return not self.__eq__(other)


+class NoQuantizer(_QuantizeHelper, Quantizer):
+  """Dummy quantizer that does nothing."""
+
+  def __init__(self):
+    """Construct a NoQuantizer.
+
+    This is an experimental API not subject to backward compatibility.
+    """
+    pass
+
+  def build(self, tensor_shape, name, layer):
+    pass
+
+  def __call__(self, inputs, training, weights, **kwargs):
+    """Return the tensor unchanged (no quantization is applied).
+
+    Args:
+      inputs: Input tensor to be quantized.
+      training: Whether the graph is currently training.
+      weights: Dictionary of weights the quantizer can use to quantize the
+        tensor. This contains the weights created in the `build` function.
+      **kwargs: Additional variables which may be passed to the quantizer.
+
+    Returns:
+      The unchanged input tensor.
+    """
+    return inputs
+
+  def get_config(self):
+    return {}
+
+  def __eq__(self, other):
+    if not isinstance(other, NoQuantizer):
+      return False
+
+    return True
+
+  def __ne__(self, other):
+    return not self.__eq__(other)
+
+
 def _types_dict():
   return {
       'AllValuesQuantizer': AllValuesQuantizer,
       'LastValueQuantizer': LastValueQuantizer,
       'MovingAverageQuantizer': MovingAverageQuantizer,
       'FixedQuantizer': FixedQuantizer,
+      'NoQuantizer': NoQuantizer,
   }
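
A short sketch of the no-op behaviour in isolation: `build` creates no weights and `__call__` hands the input straight back. The tensor values here are illustrative only.

```python
import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras import quantizers

quantizer = quantizers.NoQuantizer()
x = tf.constant([[0.1, -2.7, 3.14]])

# __call__ ignores training/weights and returns the input tensor untouched.
y = quantizer(x, training=False, weights={})
print(bool(tf.reduce_all(tf.equal(x, y))))  # expected: True
```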

tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py

Lines changed: 5 additions & 2 deletions
@@ -36,11 +36,14 @@
     quantizers.LastValueQuantizer,
     quantizers.MovingAverageQuantizer,
     quantizers.AllValuesQuantizer,
-    quantizers.FixedQuantizer)
+    quantizers.FixedQuantizer,
+    quantizers.NoQuantizer)
 class QuantizersTest(tf.test.TestCase, parameterized.TestCase):

   def _get_quant_params(self, quantizer_type):
-    if quantizer_type == quantizers.FixedQuantizer:
+    if quantizer_type == quantizers.NoQuantizer:
+      return {}
+    elif quantizer_type == quantizers.FixedQuantizer:
       return {
           'num_bits': 8,
           'init_min': 0.0,