Commit ea8e8d4

nutsiepully authored and tensorflower-gardener committed
Update QuantizeProvider to support setting weights/activations.
QuantizeProvider now supports setting weights and activations after the layer has been quantized, for both regular and RNN layers. Adds TFLite support for this, plus additional tests to ensure all cases are covered.

PiperOrigin-RevId: 264441143
1 parent a0b2291 · commit ea8e8d4
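For context, the change makes QuantizeProvider expose a getter/setter pair for weights and another for activations. A minimal sketch of a concrete provider for a Dense-like layer (the kernel/activation attribute names and the quantizer objects here are illustrative assumptions, not part of this commit):

from tensorflow_model_optimization.python.core.quantization.keras.quantize_provider import QuantizeProvider


class DenseQuantizeProvider(QuantizeProvider):
  """Illustrative provider for a layer with one kernel and one activation."""

  def __init__(self, weight_quantizer, activation_quantizer):
    # The quantizer objects are assumed to follow the Quantizer interface
    # used elsewhere in this package.
    self.weight_quantizer = weight_quantizer
    self.activation_quantizer = activation_quantizer

  def get_weights_and_quantizers(self, layer):
    # One (weight tensor, quantizer) 2-tuple per quantizable weight.
    return [(layer.kernel, self.weight_quantizer)]

  def get_activations_and_quantizers(self, layer):
    return [(layer.activation, self.activation_quantizer)]

  def set_quantize_weights(self, layer, quantize_weights):
    # Write the quantized tensors back in the order the getter returned them.
    layer.kernel = quantize_weights[0]

  def set_quantize_activations(self, layer, quantize_activations):
    layer.activation = quantize_activations[0]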

File tree

5 files changed, +346 -32 lines changed

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate_test.py

Lines changed: 6 additions & 0 deletions

@@ -37,6 +37,12 @@ def get_weights_and_quantizers(self, layer):
     def get_activations_and_quantizers(self, layer):
       pass
 
+    def set_quantize_weights(self, layer, quantize_weights):
+      pass
+
+    def set_quantize_activations(self, layer, quantize_activations):
+      pass
+
   def testAnnotatesKerasLayer(self):
     layer = keras.layers.Dense(5, activation='relu', input_shape=(10,))
     model = keras.Sequential([layer])

tensorflow_model_optimization/python/core/quantization/keras/quantize_emulate_test.py

Lines changed: 6 additions & 0 deletions

@@ -111,6 +111,12 @@ def get_weights_and_quantizers(self, layer):
       def get_activations_and_quantizers(self, layer):
         pass
 
+      def set_quantize_weights(self, layer, quantize_weights):
+        pass
+
+      def set_quantize_activations(self, layer, quantize_activations):
+        pass
+
     quantize_provider = TestQuantizeProvider()
 
     model = keras.Sequential([
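Both test files add the same stubs because the base class declares the new methods with @abc.abstractmethod (see quantize_provider.py below): a concrete provider that omits them can no longer be instantiated. A minimal standalone illustration of that mechanism, with plain abc.ABC standing in for whatever metaclass wiring the real base class uses:

import abc


class Provider(abc.ABC):

  @abc.abstractmethod
  def set_quantize_weights(self, layer, quantize_weights):
    pass

  @abc.abstractmethod
  def set_quantize_activations(self, layer, quantize_activations):
    pass


class IncompleteProvider(Provider):
  # Implements neither abstract method.
  pass


# IncompleteProvider() raises:
#   TypeError: Can't instantiate abstract class IncompleteProvider ...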

tensorflow_model_optimization/python/core/quantization/keras/quantize_provider.py

Lines changed: 29 additions & 0 deletions

@@ -47,3 +47,32 @@ def get_activations_and_quantizers(self, layer):
       quantizer.
     """
     raise NotImplementedError('Must be implemented in subclasses.')
+
+  @abc.abstractmethod
+  def set_quantize_weights(self, layer, quantize_weights):
+    """Replace the weights in the layer with quantized weights.
+
+    Args:
+      layer: layer being quantized.
+      quantize_weights: List of quantized weight tensors.
+
+    Returns:
+      List of 2-tuples. Each tuple is a weight tensor and an associated
+      quantizer.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+  @abc.abstractmethod
+  def set_quantize_activations(self, layer, quantize_activations):
+    """Replace the activations in the layer with quantized activations.
+
+    Args:
+      layer: layer being quantized.
+      quantize_activations: List of `QuantizeAwareActivation`s to replace
+        layer activations.
+
+    Returns:
+      List of 2-tuples. Each tuple is a keras activation and an associated
+      quantizer.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
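Paired with the existing getters, the new setters let a quantization pass round-trip a layer's tensors: read each (tensor, quantizer) 2-tuple, quantize, and write the results back in the same order. A hypothetical sketch of such a caller, where apply_quantizer stands in for the real quantizer invocation and is not part of this commit:

def quantize_layer_weights(provider, layer, apply_quantizer):
  # Read the (weight tensor, quantizer) 2-tuples the provider exposes.
  pairs = provider.get_weights_and_quantizers(layer)
  # Quantize each weight tensor with its associated quantizer.
  quantized = [apply_quantizer(quantizer, weight)
               for weight, quantizer in pairs]
  # Write the results back in the same order the getter produced them.
  provider.set_quantize_weights(layer, quantized)

Activations follow the same pattern through get_activations_and_quantizers and set_quantize_activations, with QuantizeAwareActivation objects in place of tensors.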

tensorflow_model_optimization/python/core/quantization/keras/tflite/tflite_quantize_registry.py

Lines changed: 74 additions & 0 deletions

@@ -305,6 +305,34 @@ def get_activations_and_quantizers(self, layer):
     return [(getattr(layer, activation_attr), self.activation_quantizer)
             for activation_attr in self.activation_attrs]
 
+  def set_quantize_weights(self, layer, quantize_weights):
+    if len(self.weight_attrs) != len(quantize_weights):
+      raise ValueError(
+          '`set_quantize_weights` called on layer {} with {} '
+          'weight parameters, but layer expects {} values.'.format(
+              layer.name, len(quantize_weights), len(self.weight_attrs)))
+
+    for weight_attr, weight in zip(self.weight_attrs, quantize_weights):
+      current_weight = getattr(layer, weight_attr)
+      if current_weight.shape != weight.shape:
+        raise ValueError('Existing layer weight shape {} is incompatible '
+                         'with provided weight shape {}'.format(
+                             current_weight.shape, weight.shape))
+
+      setattr(layer, weight_attr, weight)
+
+  def set_quantize_activations(self, layer, quantize_activations):
+    if len(self.activation_attrs) != len(quantize_activations):
+      raise ValueError(
+          '`set_quantize_activations` called on layer {} with {} '
+          'activation parameters, but layer expects {} values.'.format(
+              layer.name, len(quantize_activations),
+              len(self.activation_attrs)))
+
+    for activation_attr, activation in zip(
+        self.activation_attrs, quantize_activations):
+      setattr(layer, activation_attr, activation)
+
 
 class TFLiteQuantizeProviderRNN(TFLiteQuantizeProvider, _RNNHelper):
   """QuantizeProvider for RNN layers."""

@@ -328,3 +356,49 @@ def get_activations_and_quantizers(self, layer):
           (getattr(rnn_cell, activation_attr), self.activation_quantizer))
 
     return activations_quantizers
+
+  def _flatten(self, list_of_lists):
+    flat_list = []
+    for sublist in list_of_lists:
+      for item in sublist:
+        flat_list.append(item)
+    return flat_list
+
+  def set_quantize_weights(self, layer, quantize_weights):
+    flattened_weight_attrs = self._flatten(self.weight_attrs)
+    if len(flattened_weight_attrs) != len(quantize_weights):
+      raise ValueError(
+          '`set_quantize_weights` called on layer {} with {} '
+          'weight parameters, but layer expects {} values.'.format(
+              layer.name, len(quantize_weights), len(flattened_weight_attrs)))
+
+    i = 0
+    for weight_attrs_cell, rnn_cell in zip(
+        self.weight_attrs, self._get_rnn_cells(layer)):
+      for weight_attr in weight_attrs_cell:
+        current_weight = getattr(rnn_cell, weight_attr)
+        quantize_weight = quantize_weights[i]
+
+        if current_weight.shape != quantize_weight.shape:
+          raise ValueError('Existing layer weight shape {} is incompatible '
+                           'with provided weight shape {}'.format(
+                               current_weight.shape, quantize_weight.shape))
+
+        setattr(rnn_cell, weight_attr, quantize_weight)
+        i += 1
+
+  def set_quantize_activations(self, layer, quantize_activations):
+    flattened_activation_attrs = self._flatten(self.activation_attrs)
+    if len(flattened_activation_attrs) != len(quantize_activations):
+      raise ValueError(
+          '`set_quantize_activations` called on layer {} with {} '
+          'activation parameters, but layer expects {} values.'.format(
+              layer.name, len(quantize_activations),
+              len(flattened_activation_attrs)))
+
+    i = 0
+    for activation_attrs_cell, rnn_cell in zip(
+        self.activation_attrs, self._get_rnn_cells(layer)):
+      for activation_attr in activation_attrs_cell:
+        setattr(rnn_cell, activation_attr, quantize_activations[i])
+        i += 1
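The RNN provider stores attribute names per cell, so _flatten concatenates the per-cell lists before the length check, and the quantized values passed to the setters must arrive in cell-major order matching the getter's output. For comparison only (not part of the commit), the same flattening as a standard-library one-liner:

import itertools


def flatten(list_of_lists):
  # Equivalent to TFLiteQuantizeProviderRNN._flatten: concatenate the
  # per-cell attribute lists into one flat list.
  return list(itertools.chain.from_iterable(list_of_lists))


# flatten([['kernel', 'recurrent_kernel'], ['kernel']])
#   -> ['kernel', 'recurrent_kernel', 'kernel']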
