Skip to content

Commit e52c6f2

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Adding serialize/deserialize to Quantizers
PiperOrigin-RevId: 264460013
1 parent 9cb79a0 commit e52c6f2

File tree

2 files changed

+88
-6
lines changed

2 files changed

+88
-6
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantizers.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,22 @@ def __call__(self, inputs, step, training, **kwargs):
5353
Returns: quantized tensor.
5454
"""
5555

56+
@abc.abstractmethod
57+
def get_config(self):
58+
raise NotImplementedError('Quantizer should implement get_config().')
59+
60+
@classmethod
61+
def from_config(cls, config):
62+
"""Instantiates a `Quantizer` from its config.
63+
64+
Args:
65+
config: Output of `get_config()`.
66+
67+
Returns:
68+
A `Quantizer` instance.
69+
"""
70+
return cls(**config)
71+
5672

5773
class LastValueQuantizer(Quantizer):
5874
"""Quantize tensor based on range the last batch of values."""
@@ -96,6 +112,24 @@ def __call__(self, inputs, step, training, **kwargs):
96112
# TODO(pulkitb): Figure out a clean way to use name_prefix here.
97113
)
98114

115+
def get_config(self):
116+
return {
117+
'num_bits': self.num_bits,
118+
'per_axis': self.per_axis,
119+
'symmetric': self.symmetric,
120+
}
121+
122+
def __eq__(self, other):
123+
if not isinstance(other, LastValueQuantizer):
124+
return False
125+
126+
return (self.num_bits == other.num_bits and
127+
self.per_axis == other.per_axis and
128+
self.symmetric == other.symmetric)
129+
130+
def __ne__(self, other):
131+
return not self.__eq__(other)
132+
99133

100134
class MovingAverageQuantizer(Quantizer):
101135
"""Quantize tensor based on a moving average of values across batches."""
@@ -137,3 +171,28 @@ def __call__(self, inputs, step, training, **kwargs):
137171
narrow_range=False,
138172
# TODO(pulkitb): Figure out a clean way to use name_prefix here.
139173
)
174+
175+
def get_config(self):
176+
return {
177+
'num_bits': self.num_bits,
178+
'per_axis': self.per_axis,
179+
'symmetric': self.symmetric,
180+
}
181+
182+
def __eq__(self, other):
183+
if not isinstance(other, MovingAverageQuantizer):
184+
return False
185+
186+
return (self.num_bits == other.num_bits and
187+
self.per_axis == other.per_axis and
188+
self.symmetric == other.symmetric)
189+
190+
def __ne__(self, other):
191+
return not self.__eq__(other)
192+
193+
194+
def _types_dict():
195+
return {
196+
'LastValueQuantizer': LastValueQuantizer,
197+
'MovingAverageQuantizer': MovingAverageQuantizer
198+
}

tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,25 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
from absl.testing import parameterized
22+
2123
import numpy as np
2224

2325
from tensorflow.python.client import session
2426
from tensorflow.python.framework import dtypes
2527
from tensorflow.python.framework import ops
28+
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
29+
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
2630
from tensorflow.python.ops import variable_scope
2731
from tensorflow.python.ops import variables
2832
from tensorflow.python.platform import test
2933

3034
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
3135

3236

33-
class QuantizersTest(test.TestCase):
37+
@parameterized.parameters(
38+
quantizers.LastValueQuantizer, quantizers.MovingAverageQuantizer)
39+
class QuantizersTest(test.TestCase, parameterized.TestCase):
3440

3541
def setUp(self):
3642
super(QuantizersTest, self).setUp()
@@ -63,15 +69,32 @@ def _test_quantizer(quantizer):
6369
print('min_var: ', min_max_values[0])
6470
print('max_var: ', min_max_values[1])
6571

66-
def testLastValueQuantizer(self):
67-
quantizer = quantizers.LastValueQuantizer(**self.quant_params)
72+
def testQuantizer(self, quantizer_type):
73+
quantizer = quantizer_type(**self.quant_params)
6874

6975
self._test_quantizer(quantizer)
7076

71-
def testMovingAverageQuantizer(self):
72-
quantizer = quantizers.MovingAverageQuantizer(**self.quant_params)
77+
def testSerialization(self, quantizer_type):
78+
quantizer = quantizer_type(**self.quant_params)
7379

74-
self._test_quantizer(quantizer)
80+
expected_config = {
81+
'class_name': quantizer_type.__name__,
82+
'config': {
83+
'num_bits': 8,
84+
'per_axis': False,
85+
'symmetric': False
86+
}
87+
}
88+
serialized_quantizer = serialize_keras_object(quantizer)
89+
90+
self.assertEqual(expected_config, serialized_quantizer)
91+
92+
quantizer_from_config = deserialize_keras_object(
93+
serialized_quantizer,
94+
module_objects=globals(),
95+
custom_objects=quantizers._types_dict())
96+
97+
self.assertEqual(quantizer, quantizer_from_config)
7598

7699

77100
if __name__ == '__main__':

0 commit comments

Comments
 (0)