
Commit c8cce59

daverim authored and tensorflower-gardener committed
Add a quantize_output property to Default8BitActivationQuantizeConfig and Default8BitOutputQuantizeConfig so that wrapped layers with output quantization disabled can be loaded from disk
PiperOrigin-RevId: 432095782
1 parent 537cefb commit c8cce59
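
For context, the save/load round trip this commit targets looks roughly like the sketch below. It is a minimal illustration, not code from this change: the toy model, the Dense/Activation layer choice, and the file path are all made up, and it assumes a tfmot build that includes this revision.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Build and quantize a small model; the Activation layer is handled by a
# Default8BitActivationQuantizeConfig under the hood.
base = tf.keras.Sequential([
    tf.keras.layers.Dense(4, input_shape=(8,)),
    tf.keras.layers.Activation('relu'),
])
quant_model = tfmot.quantization.keras.quantize_model(base)

# Save and reload. On load, each QuantizeConfig is rebuilt from get_config();
# because get_config() now records quantize_output, configs whose output
# quantizer was disabled by a transform stay disabled after loading.
quant_model.save('/tmp/quantized_model.h5')
with tfmot.quantization.keras.quantize_scope():
  loaded = tf.keras.models.load_model('/tmp/quantized_model.h5')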

File tree

6 files changed: +161 −4 lines changed


tensorflow_model_optimization/python/core/quantization/keras/default_8bit/BUILD

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ py_strict_test(
     python_version = "PY3",
     deps = [
         ":default_8bit_quantize_configs",
+        ":default_8bit_quantize_registry",
         ":default_8bit_transforms",
         # absl/testing:parameterized dep1,
         # numpy dep1,

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_configs.py

Lines changed: 8 additions & 3 deletions
@@ -21,6 +21,9 @@
 class Default8BitOutputQuantizeConfig(quantize_config.QuantizeConfig):
   """QuantizeConfig which only quantizes the output from a layer."""
 
+  def __init__(self, quantize_output: bool = True) -> None:
+    self.quantize_output = quantize_output
+
   def get_weights_and_quantizers(self, layer):
     return []
 
@@ -34,11 +37,13 @@ def set_quantize_activations(self, layer, quantize_activations):
     pass
 
   def get_output_quantizers(self, layer):
-    return [quantizers.MovingAverageQuantizer(
-        num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]
+    if self.quantize_output:
+      return [quantizers.MovingAverageQuantizer(
+          num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]
+    return []
 
   def get_config(self):
-    return {}
+    return {'quantize_output': self.quantize_output}
 
 
 class NoOpQuantizeConfig(quantize_config.QuantizeConfig):
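
To see what the new get_config() entry buys, here is a minimal sketch of the config-level round trip (the variable names are illustrative; it assumes a tfmot build that contains this change):

from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs

# A config whose output quantizer has been disabled, e.g. by ConcatTransform.
config = default_8bit_quantize_configs.Default8BitOutputQuantizeConfig(
    quantize_output=False)

# get_config() now records the flag, so rebuilding the config from it
# (as happens when a wrapped model is deserialized) keeps the output
# quantizer disabled instead of silently re-enabling it.
state = config.get_config()          # {'quantize_output': False}
restored = default_8bit_quantize_configs.Default8BitOutputQuantizeConfig(
    **state)
assert restored.get_output_quantizers(layer=None) == []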

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry.py

Lines changed: 16 additions & 1 deletion
@@ -463,6 +463,15 @@ class Default8BitActivationQuantizeConfig(QuantizeConfig):
   decision to quantize depends on the specific activation type.
   """
 
+  def __init__(self, quantize_output=True):
+    """Construct a default QuantizeConfig for Activation layers.
+
+    Args:
+      quantize_output: Enable quantization of output, used to disable during
+        transform.
+    """
+    self.quantize_output = quantize_output
+
   def _assert_activation_layer(self, layer):
     if not isinstance(layer, layers.Activation):
       raise RuntimeError(
@@ -485,6 +494,8 @@ def set_quantize_activations(self, layer, quantize_activations):
 
   def get_output_quantizers(self, layer):
     self._assert_activation_layer(layer)
+    if not self.quantize_output:
+      return []
 
     if not hasattr(layer.activation, '__name__'):
       raise ValueError('Activation {} not supported by '
@@ -504,7 +515,11 @@ def get_output_quantizers(self, layer):
                        layer.activation))
 
   def get_config(self):
-    return {}
+    return {'quantize_output': self.quantize_output}
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(**config)
 
 
 class Default8BitConvQuantizeConfig(Default8BitQuantizeConfig):
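
The activation config gets the same treatment plus an explicit from_config, so the flag also survives Keras deserialization for Activation layers. A minimal sketch of that round trip, assuming a tfmot build with this change (the relu layer here is only for illustration):

import tensorflow as tf
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry

cfg = default_8bit_quantize_registry.Default8BitActivationQuantizeConfig(
    quantize_output=False)
restored = (default_8bit_quantize_registry
            .Default8BitActivationQuantizeConfig.from_config(cfg.get_config()))

# The disabled state survives the round trip: no output quantizers are
# returned for an Activation layer.
relu = tf.keras.layers.Activation('relu')
assert restored.quantize_output is False
assert restored.get_output_quantizers(relu) == []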

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py

Lines changed: 4 additions & 0 deletions
@@ -237,6 +237,10 @@ def testReturnsActivationConfig_Activation(self):
     self._assert_activation_quantizers(
         quantize_config.get_output_quantizers(activation_layer))
 
+    quantize_config.quantize_output = False
+    self.assertEmpty(
+        quantize_config.get_output_quantizers(activation_layer))
+
 
 class Default8BitQuantizeConfigTest(tf.test.TestCase, _TestHelper):
 
tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 3 additions & 0 deletions
@@ -628,6 +628,9 @@ def _get_layer_type(self, layer_class_name):
   def _disable_output_quantize(self, quantize_config):
     # TODO(pulkitb): Disabling quantize_config may also require handling
     # activation quantizers. Handle that properly.
+    if hasattr(quantize_config, 'quantize_output'):
+      quantize_config.quantize_output = False
+
     quantize_config.get_output_quantizers = lambda layer: []
 
   def replacement(self, match_layer):
tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 129 additions & 0 deletions
@@ -26,6 +26,7 @@
 from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
 from tensorflow_model_optimization.python.core.quantization.keras import quantizers
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
+from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
 from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_transforms
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
 from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
@@ -576,6 +577,134 @@ def testConcatMultipleLevels(self):
         default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
     self.assertNotEmpty(quantize_config.get_output_quantizers(None))
 
+  def testConcatActivationTransform(self):
+    r"""Tests the Concat Transform.
+
+         Input    Input
+           /        \
+        Relu        Relu
+           \        /
+            Concat
+
+    The Transform should ensure both the output FakeQuants are disabled,
+    and only a FakeQuant after Concat is present.
+    """
+    relu_1 = keras.layers.Activation('relu')
+    relu_2 = keras.layers.Activation('relu')
+    concat = keras.layers.Concatenate()
+
+    inp1 = keras.layers.Input((2,))
+    inp2 = keras.layers.Input((2,))
+    x1 = relu_1(inp1)
+    x2 = relu_2(inp2)
+    x = concat([x1, x2])
+    model = keras.Model([inp1, inp2], x)
+
+    layer_metadata = {
+        # Both relu layers have existing quantize_configs.
+        relu_1.name: {
+            'quantize_config':
+                (default_8bit_quantize_registry
+                 .Default8BitActivationQuantizeConfig())
+        },
+        relu_2.name: {
+            'quantize_config':
+                (default_8bit_quantize_registry
+                 .Default8BitActivationQuantizeConfig())
+        }
+    }
+    _, updated_metadata = ModelTransformer(
+        model, [default_8bit_transforms.ConcatTransform()],
+        layer_metadata=layer_metadata).transform()
+
+    concat_quantize_config = updated_metadata.get(
+        concat.name).get('quantize_config')
+    # Concat should quantize the output.
+    self.assertIsInstance(
+        concat_quantize_config,
+        default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
+    self.assertNotEmpty(concat_quantize_config.get_output_quantizers(None))
+
+    relu_1_quantize_config = updated_metadata.get(
+        relu_1.name).get('quantize_config')
+    # The existing quantize_config should do nothing for outputs.
+    self.assertIsInstance(
+        relu_1_quantize_config,
+        default_8bit_quantize_registry.Default8BitActivationQuantizeConfig)
+    self.assertEmpty(relu_1_quantize_config.get_output_quantizers(None))
+    self.assertFalse(relu_1_quantize_config.quantize_output)
+
+    relu_2_quantize_config = updated_metadata.get(
+        relu_2.name).get('quantize_config')
+    # The quantize_config from the registry should do nothing at the output.
+    self.assertIsInstance(
+        relu_2_quantize_config,
+        default_8bit_quantize_registry.Default8BitActivationQuantizeConfig)
+    self.assertEmpty(relu_2_quantize_config.get_output_quantizers(None))
+    self.assertFalse(relu_2_quantize_config.quantize_output)
+
+  def testConcatConcatTransformDisablesOutput(self):
+    r"""Tests the Concat Transform.
+
+      Input    Input      Input    Input
+     Reshape  Reshape    Reshape  Reshape
+         \      /            \      /
+         Concat              Concat
+              \              /
+                  Concat
+
+    The Transform should ensure all output FakeQuants are disabled,
+    and only a FakeQuant after the last Concat is present.
+    """
+    flatten_1 = keras.layers.Flatten()
+    flatten_2 = keras.layers.Flatten()
+    concat_1 = keras.layers.Concatenate()
+    flatten_3 = keras.layers.Flatten()
+    flatten_4 = keras.layers.Flatten()
+    concat_2 = keras.layers.Concatenate()
+    concat = keras.layers.Concatenate()
+
+    inp1 = keras.layers.Input((1, 2, 2))
+    inp2 = keras.layers.Input((1, 2, 2))
+    inp3 = keras.layers.Input((1, 2, 2))
+    inp4 = keras.layers.Input((1, 2, 2))
+    x1 = flatten_1(inp1)
+    x2 = flatten_2(inp2)
+    x3 = flatten_3(inp3)
+    x4 = flatten_4(inp4)
+
+    y1 = concat_1([x1, x2])
+    y2 = concat_2([x3, x4])
+    z = concat([y1, y2])
+    model = keras.Model([inp1, inp2, inp3, inp4], z)
+    reshapes = [flatten_1, flatten_2, flatten_3, flatten_4]
+    layer_metadata = {}
+    for layer in reshapes:
+      layer_metadata[layer.name] = {
+          'quantize_config':
+              default_8bit_quantize_registry.Default8BitQuantizeConfig(
                  [], [], True)}
+    _, updated_metadata = ModelTransformer(
+        model, [default_8bit_transforms.ConcatTransform()],
+        layer_metadata=layer_metadata).transform()
+
+    concat_quantize_config = updated_metadata.get(
+        concat.name).get('quantize_config')
+    # Concat should quantize the output.
+    self.assertIsInstance(
+        concat_quantize_config,
+        default_8bit_quantize_configs.Default8BitOutputQuantizeConfig)
+    self.assertNotEmpty(concat_quantize_config.get_output_quantizers(None))
+
+    # The existing quantize_configs should do nothing for outputs.
+    for layer in reshapes:
+      quantize_config = updated_metadata.get(layer.name).get('quantize_config')
+      self.assertIsInstance(
+          quantize_config,
+          default_8bit_quantize_registry.Default8BitQuantizeConfig)
+      self.assertEmpty(quantize_config.get_output_quantizers(layer))
+      self.assertFalse(quantize_config.quantize_output)
+
 
 if __name__ == '__main__':
   tf.test.main()