
Commit 3c4f3b2

daverim authored and tensorflower-gardener committed
When using quantize_annotate_layer, the expectation is that the annotated layer will be quantized after conversion to TFLite. However, this does not hold when the layer producing its inputs is not quantized.

Add a wrapper config for layers which are not annotated but are consumed by an annotated layer.

PiperOrigin-RevId: 364984791
1 parent eaeb6e7 commit 3c4f3b2
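
For context, the pattern this change addresses looks like the following sketch. It mirrors the new Sequential test added in this commit; the `tfmot` import path is the public API and an assumption of this sketch, not part of the diff:

```python
# A minimal sketch of the motivating pattern: only the final Dense layer is
# annotated for quantization, so the Dense(10) feeding it is not annotated.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1), activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    tfmot.quantization.keras.quantize_annotate_layer(
        tf.keras.layers.Dense(5, activation='softmax')),
])

# Previously, the unannotated Dense(10) kept float outputs, so the annotated
# Dense(5) received unquantized inputs after TFLite conversion. With this
# change, quantize_apply wraps Dense(10) with an output-only quantize config.
quantized_model = tfmot.quantization.keras.quantize_apply(model)
```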

File tree

2 files changed: +114 −7 lines


tensorflow_model_optimization/python/core/quantization/keras/quantize.py

Lines changed: 55 additions & 6 deletions
@@ -341,12 +341,22 @@ def _clone_model_with_weights(model_to_clone):
   def _extract_original_model(model_to_unwrap):
     """Extracts original model by removing wrappers."""
     layer_quantize_map = {}
+    requires_output_quantize = set()
 
     def _unwrap(layer):
       if not isinstance(layer, quantize_annotate_mod.QuantizeAnnotate):
         return layer
 
       annotate_wrapper = layer
+      # pylint: disable=protected-access
+      if layer._inbound_nodes and len(layer._inbound_nodes) == 1:
+        node = layer._inbound_nodes[0]
+        inbound_layers = tf.nest.flatten(node.inbound_layers)
+        if len(inbound_layers) == 1 and not isinstance(
+            inbound_layers[0], quantize_annotate_mod.QuantizeAnnotate):
+          requires_output_quantize.add(inbound_layers[0].name)
+      # pylint: enable=protected-access
+
       layer_quantize_map[annotate_wrapper.layer.name] = {
           'quantize_config': annotate_wrapper.quantize_config
       }
@@ -355,15 +365,53 @@ def _unwrap(layer):
     unwrapped_model = keras.models.clone_model(
         model_to_unwrap, input_tensors=None, clone_function=_unwrap)
 
-    return unwrapped_model, layer_quantize_map
+    return unwrapped_model, layer_quantize_map, requires_output_quantize
+
+  class OutputOnlyConfig(quantize_config_mod.QuantizeConfig):
+    """QuantizeConfig that only quantizes output."""
+
+    def __init__(self, quantize_config):
+      self.quantize_config = quantize_config
+
+    def get_weights_and_quantizers(self, layer):
+      return []
+
+    def set_quantize_weights(self, layer, quantize_weights):
+      pass
+
+    def get_activations_and_quantizers(self, layer):
+      return self.quantize_config.get_activations_and_quantizers(layer)
+
+    def set_quantize_activations(self, layer, quantize_activations):
+      return self.quantize_config.set_quantize_activations(
+          layer, quantize_activations)
+
+    def get_output_quantizers(self, layer):
+      return self.quantize_config.get_output_quantizers(layer)
+
+    def get_config(self):
+      return {'quantize_config': self.quantize_config}
+
+    @classmethod
+    def from_config(cls, config):
+      return cls(**config)
 
   def _quantize(layer):  # pylint: disable=missing-docstring
-    if layer.name not in layer_quantize_map:
+    if (layer.name not in layer_quantize_map and
+        layer.name not in requires_output_quantize):
       return layer
 
-    quantize_config = layer_quantize_map[layer.name].get('quantize_config')
-    if not quantize_config and quantize_registry.supports(layer):
-      quantize_config = quantize_registry.get_quantize_config(layer)
+    if layer.name in requires_output_quantize:
+      if not quantize_registry.supports(layer):
+        return layer
+      full_quantize_config = quantize_registry.get_quantize_config(layer)
+      if not full_quantize_config:
+        return layer
+      quantize_config = OutputOnlyConfig(full_quantize_config)
+    else:
+      quantize_config = layer_quantize_map[layer.name].get('quantize_config')
+      if not quantize_config and quantize_registry.supports(layer):
+        quantize_config = quantize_registry.get_quantize_config(layer)
 
     if not quantize_config:
       error_msg = (
@@ -395,7 +443,8 @@ def _quantize(layer):  # pylint: disable=missing-docstring
   # 2. Remove QuantizeAnnotate wrappers from the layers in the model. This
   # extracts the original model structure (easier to transform), and
   # stores relevant quantization information in a map.
-  unwrapped_model, layer_quantize_map = _extract_original_model(model_copy)
+  (unwrapped_model, layer_quantize_map,
+   requires_output_quantize) = _extract_original_model(model_copy)
   # Model cloning excludes input layers. Add input layers into the map
   # since they need to be matched for patterns as well.
   # pylint: disable=protected-access
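
To make the new behavior concrete, here is a hedged sketch of what `OutputOnlyConfig` does to the unannotated producer layer. It mirrors the `_assert_nonannotated_input_layer_quantized` helper added in the test diff below and assumes only public `tfmot` API names:

```python
# Hedged sketch: inspect the wrapper that quantize_apply now places around an
# unannotated layer that feeds an annotated one.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1), activation='relu'),
    tf.keras.layers.Dense(10, activation='relu'),
    tfmot.quantization.keras.quantize_annotate_layer(
        tf.keras.layers.Dense(5, activation='softmax')),
])
quantized_model = tfmot.quantization.keras.quantize_apply(model)

# layers[1] is the unannotated Dense(10), wrapped because it feeds the
# annotated Dense(5).
wrapped = quantized_model.layers[1]
config = wrapped.quantize_config

# OutputOnlyConfig skips weight quantization entirely...
assert config.get_weights_and_quantizers(wrapped.layer) == []
# ...while output handling is delegated to the registry's full config.
full_config = config.get_config()['quantize_config']
assert (config.get_output_quantizers(wrapped.layer) ==
        full_config.get_output_quantizers(wrapped.layer))
```

The design point: the producer's weights never need quantizing for the consumer's sake; only the tensor crossing into the annotated layer matters, so the wrapper forwards activation and output handling to the underlying config and drops the rest.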

tensorflow_model_optimization/python/core/quantization/keras/quantize_test.py

Lines changed: 59 additions & 1 deletion
@@ -288,12 +288,44 @@ def _assert_model_quantized(
         zip(annotated_model.layers, quantized_model.layers):
 
       if not isinstance(layer_annotated, QuantizeAnnotate):
-        self.assertNotIsInstance(layer_quantized, QuantizeWrapper)
+        # Possibly wrapped for input quantization.
+        if isinstance(layer_quantized, QuantizeWrapper):
+          self.assertLen(layer_annotated._outbound_nodes, 1)
+          self.assertIsInstance(
+              layer_annotated._outbound_nodes[0].outbound_layer,
+              QuantizeAnnotate)
+
+          # Ensure that only outputs are quantized.
+          self.assertFalse(
+              layer_quantized.quantize_config.get_weights_and_quantizers(
+                  layer_quantized.layer))
         continue
 
       self._assert_layer_quantized(
           layer_annotated, layer_quantized, exclude_keys)
 
+  def _assert_nonannotated_input_layer_quantized(
+      self, quantized_model, layer_index):
+    output_quantized_layer = quantized_model.layers[layer_index]
+    self.assertIsInstance(output_quantized_layer, QuantizeWrapper)
+    output_quantized_config = output_quantized_layer.quantize_config
+    default_quantized_config = output_quantized_config.get_config()[
+        'quantize_config']
+    self.assertIsInstance(
+        default_quantized_config,
+        default_8bit_quantize_registry.Default8BitQuantizeConfig)
+    self.assertFalse(output_quantized_config.get_weights_and_quantizers(
+        output_quantized_layer.layer))
+    self.assertTrue(default_quantized_config.get_weights_and_quantizers(
+        output_quantized_layer.layer))
+    output_quantizers = output_quantized_config.get_output_quantizers(
+        output_quantized_layer.layer)
+    default_output_quantizers = default_quantized_config.get_output_quantizers(
+        output_quantized_layer.layer)
+    for output_quantizer, default_quantizer in zip(output_quantizers,
+                                                   default_output_quantizers):
+      self.assertEqual(output_quantizer, default_quantizer)
+
   # quantize_apply Tests
 
   class CustomLayer(keras.layers.Dense):
@@ -379,6 +411,18 @@ def testAppliesQuantizationToAnnotatedModel_Sequential(self):
 
     self._assert_model_quantized(model, quantized_model, ['activation'])
 
+  def testAppliesQuantizationToInputsToAnnotatedModel_Sequential(self):
+    model = keras.Sequential([
+        keras.layers.Conv2D(32, 5, input_shape=(28, 28, 1), activation='relu'),
+        keras.layers.Dense(10, activation='relu'),
+        quantize_annotate_layer(keras.layers.Dense(5, activation='softmax')),
+    ])
+    quantized_model = quantize_apply(model)
+    self._assert_model_quantized(model, quantized_model, ['activation'])
+    # Test that Dense layer has output only quantization config.
+    self._assert_nonannotated_input_layer_quantized(
+        quantized_model, layer_index=1)
+
   def testAppliesQuantizationToAnnotatedModel_PreservesBuiltState(self):
     model = keras_test_utils.build_simple_dense_model()
     annotated_model = quantize_annotate_model(model)
@@ -404,6 +448,20 @@ def testAppliesQuantizationToAnnotatedModel_Functional(self):
 
     self._assert_model_quantized(model, quantized_model, ['activation'])
 
+  def testAppliesQuantizationToInputsToAnnotatedModel_Functional(self):
+    inputs = keras.Input(shape=(28, 28, 1))
+    x = keras.layers.Conv2D(32, 5, activation='relu')(inputs)
+    x = keras.layers.Dense(10, activation='relu')(x)
+    results = quantize_annotate_layer(
+        keras.layers.Dense(5, activation='softmax'))(
+            x)
+    model = keras.Model(inputs=inputs, outputs=results)
+    quantized_model = quantize_apply(model)
+    self._assert_model_quantized(model, quantized_model, ['activation'])
+    # Test that Dense layer has output only quantization config.
+    self._assert_nonannotated_input_layer_quantized(
+        quantized_model, layer_index=2)
+
   def testDoesNotQuantizeInputLayer_OutboundLayerNotQuantized(self):
     model = self._get_simple_functional_model()
 
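Finally, the commit message's motivating scenario is TFLite conversion. A hedged sketch, continuing from `quantized_model` in the sketch above the test diff; the converter flags are the usual post-QAT recipe and an assumption, not part of this change:

```python
import tensorflow as tf

# `quantized_model` is from the earlier OutputOnlyConfig sketch. Standard
# QAT-to-TFLite conversion recipe (assumed; not part of this commit).
converter = tf.lite.TFLiteConverter.from_keras_model(quantized_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# With this change the unannotated Dense(10) quantizes its outputs, so the
# annotated Dense(5) sees quantized inputs in the converted graph, matching
# the expectation described in the commit message.
```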