@@ -80,7 +80,7 @@ def quantize_scope(*args):
80
80
return tf .keras .utils .custom_object_scope (* (args + (quantization_objects ,)))
81
81
82
82
83
- def quantize_model (to_quantize ):
83
+ def quantize_model (to_quantize , quantized_layer_name_prefix = 'quant_' ):
84
84
"""Quantize a `tf.keras` model with the default quantization implementation.
85
85
86
86
Quantization constructs a model which emulates quantization during training.
@@ -117,13 +117,18 @@ def quantize_model(to_quantize):
117
117
Args:
118
118
to_quantize: tf.keras model to be quantized. It can have pre-trained
119
119
weights.
120
+ quantized_layer_name_prefix: Name prefix for the quantized layers. The
121
+ default is `quant_`.
120
122
121
123
Returns:
122
124
Returns a new `tf.keras` model prepared for quantization.
123
125
"""
124
126
if to_quantize is None :
125
127
raise ValueError ('`to_quantize` cannot be None' )
126
128
129
+ if quantized_layer_name_prefix is None :
130
+ quantized_layer_name_prefix = ''
131
+
127
132
if not isinstance (to_quantize , keras .Model ):
128
133
raise ValueError (
129
134
'`to_quantize` can only be a `tf.keras.Model` instance. Use '
@@ -138,7 +143,8 @@ def quantize_model(to_quantize):
138
143
'Functional model.' )
139
144
140
145
annotated_model = quantize_annotate_model (to_quantize )
141
- return quantize_apply (annotated_model )
146
+ return quantize_apply (
147
+ annotated_model , quantized_layer_name_prefix = quantized_layer_name_prefix )
142
148
143
149
144
150
def quantize_annotate_model (to_annotate ):
@@ -281,7 +287,8 @@ def quantize_annotate_layer(to_annotate, quantize_config=None):
281
287
@metrics .MonitorBoolGauge ('quantize_apply_usage' )
282
288
def quantize_apply (
283
289
model ,
284
- scheme = default_8bit_quantize_scheme .Default8BitQuantizeScheme ()):
290
+ scheme = default_8bit_quantize_scheme .Default8BitQuantizeScheme (),
291
+ quantized_layer_name_prefix = 'quant_' ):
285
292
"""Quantize a `tf.keras` model that has been annotated for quantization.
286
293
287
294
Quantization constructs a model which emulates quantization during training.
@@ -319,6 +326,8 @@ def quantize_apply(
319
326
with `quantize_annotate`. It can have pre-trained weights.
320
327
scheme: A `QuantizeScheme` which specifies transformer and quantization
321
328
registry. The default is `Default8BitQuantizeScheme()`.
329
+ quantized_layer_name_prefix: A name prefix for quantized layers. The default
330
+ is `quant_`.
322
331
323
332
Returns:
324
333
Returns a new `tf.keras` model in which the annotated layers have been
@@ -327,6 +336,9 @@ def quantize_apply(
327
336
if model is None :
328
337
raise ValueError ('`model` cannot be None' )
329
338
339
+ if quantized_layer_name_prefix is None :
340
+ quantized_layer_name_prefix = ''
341
+
330
342
if not isinstance (model , keras .Model ):
331
343
raise ValueError ('`model` can only be a `tf.keras.Model` instance.'
332
344
'You passed an instance of type: {input}.' .format (
@@ -435,7 +447,7 @@ def _quantize(layer): # pylint: disable=missing-docstring
435
447
# `QuantizeAnnotate` wrapper may contain `batch_input_shape` like params.
436
448
# TODO(pulkitb): Ensure this does not affect model cloning.
437
449
return quantize_wrapper .QuantizeWrapperV2 (
438
- layer , quantize_config )
450
+ layer , quantize_config , name_prefix = quantized_layer_name_prefix )
439
451
440
452
# 1. Create a copy of the model with the same weights. This ensures
441
453
# modifications don't affect the original model, or its weights.
@@ -446,7 +458,7 @@ def _quantize(layer): # pylint: disable=missing-docstring
446
458
'Unable to clone model. This generally happens if you used custom '
447
459
'Keras layers or objects in your model. Please specify them via '
448
460
'`quantize_scope` for your calls to `quantize_model` and '
449
- '`quantize_apply`. [%s].' % er )
461
+ '`quantize_apply`. [%s].' % er ) from er
450
462
451
463
# 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
452
464
# extracts the original model structure (easier to transform), and
0 commit comments