Skip to content

Commit 1606dc9

Browse files
daverim authored and tensorflower-gardener committed
Update Concat transforms to handle QuantizeLayer inputs by removing quantizer. Allow quantize layer to take None quantizers.
PiperOrigin-RevId: 406065456
1 parent 5e494aa commit 1606dc9

File tree

4 files changed

+45
-5
lines changed

4 files changed

+45
-5
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,8 @@ def pattern(self):
504504
'Concatenate', inputs=[LayerPattern('.*'), LayerPattern('.*')])
505505

506506
def _get_layer_type(self, layer_class_name):
507+
if layer_class_name == 'QuantizeLayer':
508+
return quantize_layer.QuantizeLayer
507509
keras_layers = inspect.getmembers(tf.keras.layers, inspect.isclass)
508510
for layer_name, layer_type in keras_layers:
509511
if layer_name == layer_class_name:
@@ -537,6 +539,10 @@ def replacement(self, match_layer):
537539
default_8bit_quantize_configs.NoOpQuantizeConfig())
538540
continue
539541

542+
if layer_class == quantize_layer.QuantizeLayer:
543+
feed_layer_node.metadata['quantizer'] = None
544+
continue
545+
540546
if not default_registry._is_supported_layer(layer_class):
541547
# Feeding layer is not supported by registry
542548
return match_layer

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,18 @@ def _quantize(layer): # pylint: disable=missing-docstring
375375
# It supports for custom QuantizeWrapper.
376376
return layer
377377

378+
# layer is a QuantizeLayer, possibly rebuild
379+
# layer with modified config from parameters stored in the map.
380+
if isinstance(layer, quantize_layer.QuantizeLayer):
381+
# If there is more than one usage of the input, even if all are concat,
382+
# we need to quantize.
383+
if len(layer._outbound_nodes) > 1: # pylint: disable=protected-access
384+
return layer
385+
layer_config = layer.get_config()
386+
for key, value in layer_quantize_map[layer.name].items():
387+
layer_config[key] = value
388+
return quantize_layer.QuantizeLayer.from_config(layer_config)
389+
378390
if layer.name in requires_output_quantize:
379391
if not quantize_registry.supports(layer):
380392
return layer

tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,23 @@ def __init__(self, quantizer, **kwargs):
3838
"""Create a QuantizeLayer.
3939
4040
Args:
41-
quantizer: `Quantizer` used to quantize tensors.
41+
quantizer: `Quantizer` used to quantize tensors. quantizer=None
42+
means no quantization of the input layer.
4243
**kwargs: Additional keyword arguments to be passed to the keras layer.
4344
"""
4445
super(QuantizeLayer, self).__init__(**kwargs)
4546

46-
if quantizer is None or not isinstance(quantizer, quantizers.Quantizer):
47-
raise ValueError('quantizer should not be None, and should be an instance'
47+
if quantizer is not None and not isinstance(quantizer,
48+
quantizers.Quantizer):
49+
raise ValueError('quantizer should be an instance'
4850
'of `tfmot.quantization.keras.quantizers.Quantizer`.')
4951

5052
self.quantizer = quantizer
5153

5254
def build(self, input_shape):
53-
self.quantizer_vars = self.quantizer.build(
54-
input_shape, self.name, self)
55+
if self.quantizer:
56+
self.quantizer_vars = self.quantizer.build(
57+
input_shape, self.name, self)
5558

5659
self.optimizer_step = self.add_weight(
5760
'optimizer_step',
@@ -60,6 +63,9 @@ def build(self, input_shape):
6063
trainable=False)
6164

6265
def call(self, inputs, training=None):
66+
if not self.quantizer:
67+
return inputs
68+
6369
if training is None:
6470
training = tf.keras.backend.learning_phase()
6571

tensorflow_model_optimization/python/core/quantization/keras/quantize_layer_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ def testSerializationQuantizeLayer(self):
6969

7070
self.assertEqual(layer_from_config.get_config(), layer.get_config())
7171

72+
def testNoQuantizeLayer(self):
73+
layer = QuantizeLayer(quantizer=None, input_shape=(4,))
74+
model = tf.keras.Sequential([layer])
75+
x = np.random.rand(1, 4)
76+
self.assertAllClose(x, model.predict(x))
77+
78+
custom_objects = {
79+
'QuantizeLayer': QuantizeLayer,
80+
}
81+
82+
serialized_layer = serialize_layer(layer)
83+
with tf.keras.utils.custom_object_scope(custom_objects):
84+
layer_from_config = deserialize_layer(serialized_layer)
85+
86+
self.assertEqual(layer_from_config.get_config(), layer.get_config())
87+
7288

7389
if __name__ == '__main__':
7490
tf.test.main()

0 commit comments

Comments (0)