Skip to content

Commit 15d6ea9

Browse files
Added an argument to support a user-provided name prefix
PiperOrigin-RevId: 491514531
1 parent c3c0042 commit 15d6ea9

File tree

3 files changed

+78
-20
lines changed

3 files changed

+78
-20
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def quantize_scope(*args):
8080
return tf.keras.utils.custom_object_scope(*(args + (quantization_objects,)))
8181

8282

83-
def quantize_model(to_quantize):
83+
def quantize_model(to_quantize, quantized_layer_name_prefix='quant_'):
8484
"""Quantize a `tf.keras` model with the default quantization implementation.
8585
8686
Quantization constructs a model which emulates quantization during training.
@@ -117,13 +117,18 @@ def quantize_model(to_quantize):
117117
Args:
118118
to_quantize: tf.keras model to be quantized. It can have pre-trained
119119
weights.
120+
quantized_layer_name_prefix: Name prefix for the quantized layers. The
121+
default is `quant_`.
120122
121123
Returns:
122124
Returns a new `tf.keras` model prepared for quantization.
123125
"""
124126
if to_quantize is None:
125127
raise ValueError('`to_quantize` cannot be None')
126128

129+
if quantized_layer_name_prefix is None:
130+
quantized_layer_name_prefix = ''
131+
127132
if not isinstance(to_quantize, keras.Model):
128133
raise ValueError(
129134
'`to_quantize` can only be a `tf.keras.Model` instance. Use '
@@ -138,7 +143,8 @@ def quantize_model(to_quantize):
138143
'Functional model.')
139144

140145
annotated_model = quantize_annotate_model(to_quantize)
141-
return quantize_apply(annotated_model)
146+
return quantize_apply(
147+
annotated_model, quantized_layer_name_prefix=quantized_layer_name_prefix)
142148

143149

144150
def quantize_annotate_model(to_annotate):
@@ -281,7 +287,8 @@ def quantize_annotate_layer(to_annotate, quantize_config=None):
281287
@metrics.MonitorBoolGauge('quantize_apply_usage')
282288
def quantize_apply(
283289
model,
284-
scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme()):
290+
scheme=default_8bit_quantize_scheme.Default8BitQuantizeScheme(),
291+
quantized_layer_name_prefix='quant_'):
285292
"""Quantize a `tf.keras` model that has been annotated for quantization.
286293
287294
Quantization constructs a model which emulates quantization during training.
@@ -319,6 +326,8 @@ def quantize_apply(
319326
with `quantize_annotate`. It can have pre-trained weights.
320327
scheme: A `QuantizeScheme` which specifies transformer and quantization
321328
registry. The default is `Default8BitQuantizeScheme()`.
329+
quantized_layer_name_prefix: A name prefix for quantized layers. The default
330+
is `quant_`.
322331
323332
Returns:
324333
Returns a new `tf.keras` model in which the annotated layers have been
@@ -327,6 +336,9 @@ def quantize_apply(
327336
if model is None:
328337
raise ValueError('`model` cannot be None')
329338

339+
if quantized_layer_name_prefix is None:
340+
quantized_layer_name_prefix = ''
341+
330342
if not isinstance(model, keras.Model):
331343
raise ValueError('`model` can only be a `tf.keras.Model` instance.'
332344
'You passed an instance of type: {input}.'.format(
@@ -435,7 +447,7 @@ def _quantize(layer): # pylint: disable=missing-docstring
435447
# `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
436448
# TODO(pulkitb): Ensure this does not affect model cloning.
437449
return quantize_wrapper.QuantizeWrapperV2(
438-
layer, quantize_config)
450+
layer, quantize_config, name_prefix=quantized_layer_name_prefix)
439451

440452
# 1. Create a copy of the model with the same weights. This ensures
441453
# modifications don't affect the original model, or its weights.
@@ -446,7 +458,7 @@ def _quantize(layer): # pylint: disable=missing-docstring
446458
'Unable to clone model. This generally happens if you used custom '
447459
'Keras layers or objects in your model. Please specify them via '
448460
'`quantize_scope` for your calls to `quantize_model` and '
449-
'`quantize_apply`. [%s].' % er)
461+
'`quantize_apply`. [%s].' % er) from er
450462

451463
# 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
452464
# extracts the original model structure (easier to transform), and

tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import tempfile
2323

2424
from absl.testing import parameterized
25-
2625
import numpy as np
2726
import tensorflow as tf
2827

@@ -108,13 +107,55 @@ def testModelEndToEnd(self, model_type):
108107
model.fit(x_train, y_train)
109108

110109
# 3. Ensure conversion to TFLite works.
111-
_, tflite_file = tempfile.mkstemp('.tflite')
112-
print('TFLite File: ', tflite_file)
113-
with quantize.quantize_scope():
114-
utils.convert_keras_to_tflite(model, tflite_file)
110+
with tempfile.NamedTemporaryFile(suffix='.tflite') as t:
111+
with quantize.quantize_scope():
112+
utils.convert_keras_to_tflite(model, t.name)
113+
114+
# 4. Verify input runs on converted model.
115+
self._verify_tflite(t.name, x_train, y_train)
116+
117+
# Test the model with custom layer name prefix.
118+
@parameterized.product(
119+
model_type=_KERAS_APPLICATION_MODELS,
120+
name_prefix=['', 'custom_prefix_'])
121+
def testModelEndToEndCustomNamePrefix(self, model_type, name_prefix):
122+
# 1. Check whether quantized model graph can be constructed.
123+
model = self._get_model(model_type)
124+
original_layer_names = set([layer.name for layer in model.layers])
125+
126+
model = quantize.quantize_model(
127+
model, quantized_layer_name_prefix=name_prefix)
128+
quantized_layer_names = set([layer.name for layer in model.layers])
129+
130+
# Remove the name of layer which is newly added to quantize the input.
131+
quantized_layer_names.remove('quantize_layer')
132+
133+
if not name_prefix or name_prefix is None:
134+
# The set of layer names should be the same.
135+
self.assertEqual(original_layer_names, quantized_layer_names)
136+
else:
137+
self.assertNotEqual(original_layer_names, quantized_layer_names)
138+
for name in original_layer_names:
139+
if name in quantized_layer_names:
140+
quantized_layer_names.remove(name)
141+
elif name_prefix + name in quantized_layer_names:
142+
quantized_layer_names.remove(name_prefix + name)
143+
144+
self.assertEmpty(quantized_layer_names)
145+
146+
# 2. Sanity check to ensure basic training on random data works.
147+
x_train, y_train = self._create_test_data(model)
148+
model.compile(
149+
loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
150+
model.fit(x_train, y_train)
151+
152+
# 3. Ensure conversion to TFLite works.
153+
with tempfile.NamedTemporaryFile(suffix='.tflite') as t:
154+
with quantize.quantize_scope():
155+
utils.convert_keras_to_tflite(model, t.name)
115156

116-
# 4. Verify input runs on converted model.
117-
self._verify_tflite(tflite_file, x_train, y_train)
157+
# 4. Verify input runs on converted model.
158+
self._verify_tflite(t.name, x_train, y_train)
118159

119160

120161
if __name__ == '__main__':

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,22 @@
4141
class QuantizeWrapper(tf.keras.layers.Wrapper):
4242
"""Quantizes the weights and activations of the keras layer it wraps."""
4343

44-
def __init__(self, layer, quantize_config, **kwargs):
44+
def __init__(self, layer, quantize_config, name_prefix='quant_', **kwargs):
4545
"""Create a quantize emulate wrapper for a keras layer.
4646
4747
Args:
4848
layer: The keras layer to be quantized.
4949
quantize_config: `QuantizeConfig` to quantize layer.
50+
name_prefix: Prefix for quantized keras layer name. The default is
51+
`quant_`.
5052
**kwargs: Additional keyword arguments to be passed to the keras layer.
5153
"""
5254
if layer is None:
5355
raise ValueError('`layer` cannot be None.')
5456

57+
if name_prefix is None:
58+
name_prefix = ''
59+
5560
# Check against keras.Model since it is an instance of keras.layers.Layer.
5661
if not isinstance(layer, tf.keras.layers.Layer) or isinstance(
5762
layer, tf.keras.Model):
@@ -65,7 +70,7 @@ def __init__(self, layer, quantize_config, **kwargs):
6570
'quantize a layer.')
6671

6772
if 'name' not in kwargs:
68-
kwargs['name'] = self._make_layer_name(layer)
73+
kwargs['name'] = self._make_layer_name(layer, name_prefix)
6974

7075
super(QuantizeWrapper, self).__init__(layer, **kwargs)
7176
self.quantize_config = quantize_config
@@ -74,8 +79,8 @@ def __init__(self, layer, quantize_config, **kwargs):
7479
metrics.MonitorBoolGauge('quantize_wrapper_usage').set(
7580
layer.__class__.__name__)
7681

77-
def _make_layer_name(self, layer):
78-
return '{}_{}'.format('quant', layer.name)
82+
def _make_layer_name(self, layer, name_prefix):
83+
return '{}{}'.format(name_prefix, layer.name)
7984

8085
def _weight_name(self, name):
8186
"""Extracts the weight name from the full TensorFlow variable name.
@@ -100,8 +105,8 @@ def build(self, input_shape):
100105
trainable=False)
101106

102107
self._weight_vars = []
103-
for weight, quantizer in \
104-
self.quantize_config.get_weights_and_quantizers(self.layer):
108+
for weight, quantizer in (
109+
self.quantize_config.get_weights_and_quantizers(self.layer)):
105110
quantizer_vars = quantizer.build(weight.shape,
106111
self._weight_name(weight.name), self)
107112

@@ -110,8 +115,8 @@ def build(self, input_shape):
110115
self._trainable_weights.append(weight)
111116

112117
self._quantize_activations = []
113-
for activation, quantizer in \
114-
self.quantize_config.get_activations_and_quantizers(self.layer):
118+
for activation, quantizer in (
119+
self.quantize_config.get_activations_and_quantizers(self.layer)):
115120
quantize_activation = quantize_aware_activation.QuantizeAwareActivation(
116121
activation, quantizer, self.optimizer_step, self)
117122

0 commit comments

Comments (0)