Skip to content

Commit b61065a

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Serialize/Deserialize support for QuantizeProviders
PiperOrigin-RevId: 264463346
1 parent e52c6f2 commit b61065a

File tree

5 files changed

+96
-0
lines changed

5 files changed

+96
-0
lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def set_quantize_weights(self, layer, quantize_weights):
4343
def set_quantize_activations(self, layer, quantize_activations):
4444
pass
4545

46+
def get_config(self):
47+
pass
48+
4649
def testAnnotatesKerasLayer(self):
4750
layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
4851
model = keras.Sequential([layer])

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ def set_quantize_weights(self, layer, quantize_weights):
117117
def set_quantize_activations(self, layer, quantize_activations):
118118
pass
119119

120+
def get_config(self):
121+
pass
122+
120123
quantize_provider = TestQuantizeProvider()
121124

122125
model = keras.Sequential([

tensorflow_model_optimization/python/core/quantization/keras/quantize_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,7 @@ def set_quantize_activations(self, layer, quantize_activations):
7676
quantizer.
7777
"""
7878
raise NotImplementedError('Must be implemented in subclasses.')
79+
80+
@abc.abstractmethod
81+
def get_config(self):
82+
raise NotImplementedError('QuantizeProvider should implement get_config().')

tensorflow_model_optimization/python/core/quantization/keras/tflite/tflite_quantize_registry.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,39 @@ def set_quantize_activations(self, layer, quantize_activations):
333333
zip(self.activation_attrs, quantize_activations):
334334
setattr(layer, activation_attr, activation)
335335

336+
@classmethod
337+
def from_config(cls, config):
338+
"""Instantiates a `TFLiteQuantizeProvider` from its config.
339+
340+
Args:
341+
config: Output of `get_config()`.
342+
343+
Returns:
344+
A `TFLiteQuantizeProvider` instance.
345+
"""
346+
return cls(**config)
347+
348+
def get_config(self):
349+
# TODO(pulkitb): Add weight and activation quantizer to config.
350+
# Currently it's created internally, but ideally the quantizers should be
351+
# part of the constructor and passed in from the registry.
352+
return {
353+
'weight_attrs': self.weight_attrs,
354+
'activation_attrs': self.activation_attrs,
355+
}
356+
357+
def __eq__(self, other):
358+
if not isinstance(other, TFLiteQuantizeProvider):
359+
return False
360+
361+
return (self.weight_attrs == other.weight_attrs and
362+
self.activation_attrs == self.activation_attrs and
363+
self.weight_quantizer == other.weight_quantizer and
364+
self.activation_quantizer == other.activation_quantizer)
365+
366+
def __ne__(self, other):
367+
return not self.__eq__(other)
368+
336369

337370
class TFLiteQuantizeProviderRNN(TFLiteQuantizeProvider, _RNNHelper):
338371
"""QuantizeProvider for RNN layers."""
@@ -402,3 +435,10 @@ def set_quantize_activations(self, layer, quantize_activations):
402435
for activation_attr in activation_attrs_cell:
403436
setattr(rnn_cell, activation_attr, quantize_activations[i])
404437
i += 1
438+
439+
440+
def _types_dict():
441+
return {
442+
'TFLiteQuantizeProvider': TFLiteQuantizeProvider,
443+
'TFLiteQuantizeProviderRNN': TFLiteQuantizeProviderRNN
444+
}

tensorflow_model_optimization/python/core/quantization/keras/tflite/tflite_quantize_registry_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from tensorflow.python import keras
2424
from tensorflow.python.keras import backend as K
2525
from tensorflow.python.keras import layers as l
26+
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
27+
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
2628
from tensorflow.python.platform import test
2729

2830
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
@@ -271,6 +273,28 @@ def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
271273
quantize_provider.set_quantize_activations(
272274
layer, [quantize_activation, quantize_activation])
273275

276+
def testSerialization(self):
277+
quantize_provider = tflite_quantize_registry.TFLiteQuantizeProvider(
278+
['kernel'], ['activation'])
279+
280+
expected_config = {
281+
'class_name': 'TFLiteQuantizeProvider',
282+
'config': {
283+
'weight_attrs': ['kernel'],
284+
'activation_attrs': ['activation'],
285+
}
286+
}
287+
serialized_quantize_provider = serialize_keras_object(quantize_provider)
288+
289+
self.assertEqual(expected_config, serialized_quantize_provider)
290+
291+
quantize_provider_from_config = deserialize_keras_object(
292+
serialized_quantize_provider,
293+
module_objects=globals(),
294+
custom_objects=tflite_quantize_registry._types_dict())
295+
296+
self.assertEqual(quantize_provider, quantize_provider_from_config)
297+
274298

275299
class TFLiteQuantizeProviderRNNTest(test.TestCase, _TestHelper):
276300

@@ -367,6 +391,28 @@ def testSetsQuantizeActivations_ErrorOnWrongNumberOfActivations(self):
367391
self.quantize_provider.set_quantize_activations(
368392
self.layer, [quantize_activation, quantize_activation])
369393

394+
def testSerialization(self):
395+
expected_config = {
396+
'class_name': 'TFLiteQuantizeProviderRNN',
397+
'config': {
398+
'weight_attrs': [['kernel', 'recurrent_kernel'],
399+
['kernel', 'recurrent_kernel']],
400+
'activation_attrs': [['activation', 'recurrent_activation'],
401+
['activation', 'recurrent_activation']],
402+
}
403+
}
404+
serialized_quantize_provider = serialize_keras_object(
405+
self.quantize_provider)
406+
407+
self.assertEqual(expected_config, serialized_quantize_provider)
408+
409+
quantize_provider_from_config = deserialize_keras_object(
410+
serialized_quantize_provider,
411+
module_objects=globals(),
412+
custom_objects=tflite_quantize_registry._types_dict())
413+
414+
self.assertEqual(self.quantize_provider, quantize_provider_from_config)
415+
370416

371417
if __name__ == '__main__':
372418
test.main()

0 commit comments

Comments
 (0)