Skip to content

Commit f251f71

Browse files
nutsiepully authored and
tensorflower-gardener committed
Only set weights returned by Transform.
Currently, ModelTransformer expects every weight required by a layer to be returned by the Transform, and sets all of them. This is a problem because the Transform might not know about all the weights, or might want to set only some of them. Some weights are created by the layer itself in its build method. For example, ConvBatchNorm layers create quant variables that are determined by the weights and quantizers; the Transform does not know about them. This CL allows the Transform to set only the specific weights it wants: only the weights present in the returned map are set, rather than the layer's entire weight list. PiperOrigin-RevId: 278967295
1 parent 0b4eb1c commit f251f71

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import copy
2020

2121
from tensorflow.python import keras
22+
from tensorflow.python.keras import backend as K
2223

2324
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms as transforms_mod
2425

@@ -302,15 +303,41 @@ def _add_replacement_layer(layer_node):
302303

303304
_add_replacement_layer(replacement_layer_node)
304305

306+
@staticmethod
307+
def _weight_name(name):
308+
"""Extracts the weight name by removing layer from TF variable name.
309+
310+
For example, returns 'kernel:0' for 'dense_2/kernel:0'.
311+
312+
Args:
313+
name: TensorFlow variable name.
314+
315+
Returns:
316+
Extracted weight name.
317+
"""
318+
return name.split('/')[-1]
319+
305320
def _get_keras_layer_weights(self, keras_layer):
306321
"""Returns a map of weight name, weight matrix. Keeps keras ordering."""
307322
weights_map = collections.OrderedDict()
308323
for weight_tensor, weight_numpy in \
309324
zip(keras_layer.weights, keras_layer.get_weights()):
310-
weights_map[weight_tensor.name] = weight_numpy
325+
weights_map[self._weight_name(weight_tensor.name)] = weight_numpy
311326

312327
return weights_map
313328

329+
def _set_layer_weights(self, layer, weights_map):
    """Sets the values of weights in a Keras layer.

    Only weights whose stripped name (see `self._weight_name`) appears as a
    key in `weights_map` are assigned; every other weight of the layer is
    left untouched. Entries of `weights_map` that match no weight of the
    layer are ignored.

    Args:
      layer: Keras layer whose `weights` are to be updated.
      weights_map: Mapping from stripped weight name to the value to set.
    """
    weight_value_tuples = []
    for weight_tensor in layer.weights:
        weight_name = self._weight_name(weight_tensor.name)
        if weight_name in weights_map:
            weight_value_tuples.append(
                (weight_tensor, weights_map[weight_name]))

    # All matched (variable, value) pairs are handed to the backend in a
    # single call.
    K.batch_set_value(weight_value_tuples)
340+
314341
def transform(self):
315342
"""Transforms the Keras model by applying all the specified transforms.
316343
@@ -390,7 +417,7 @@ def transform(self):
390417
for layer in transformed_model.layers:
391418
weights = self._layer_weights_map.get(layer.name)
392419
if weights:
393-
layer.set_weights(list(weights.values()))
420+
self._set_layer_weights(layer, weights)
394421

395422
# TODO(pulkitb): Consider returning the updated metadata for the
396423
# transformed model along with the model. This allows the opportunity for

tensorflow_model_optimization/python/core/quantization/keras/tflite/tflite_transforms.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,10 @@ def _get_conv_bn_layers(bn_layer_node):
3737

3838
def _get_weights(bn_layer_node):
3939
"""Returns weight values for fused layer, including copying original values in unfused version."""
40-
weights = collections.OrderedDict()
41-
42-
bn_layer_weights = list(bn_layer_node.weights.items())
43-
conv_layer_weights = list(bn_layer_node.input_layers[0].weights.items())
44-
45-
weights['conv/kernel'] = conv_layer_weights[0][1]
46-
weights['batch_normalization/gamma:0'] = bn_layer_weights[0][1]
47-
weights['batch_normalization/beta:0'] = bn_layer_weights[1][1]
48-
49-
# TODO(tfmot): remove hardcoded initialization values.
50-
weights['weight_min'] = np.array(-6.0)
51-
weights['weight_max'] = np.array(6.0)
52-
weights['optimizer_step'] = np.array(-1)
53-
weights['activation_min'] = np.array(-6.0)
54-
weights['activation_max'] = np.array(6.0)
55-
56-
weights['batch_normalization/moving_mean:0'] = bn_layer_weights[2][1]
57-
weights['batch_normalization/moving_variance:0'] = bn_layer_weights[3][1]
58-
59-
return weights
6040

41+
return collections.OrderedDict(
42+
list(bn_layer_node.input_layers[0].weights.items())
43+
+ list(bn_layer_node.weights.items()))
6144

6245
def _get_params(conv_layer, bn_layer, relu_layer=None):
6346
"""Retrieve conv_bn params within wrapped layers."""

0 commit comments

Comments
 (0)