|
19 | 19 | import copy
|
20 | 20 |
|
21 | 21 | from tensorflow.python import keras
|
| 22 | +from tensorflow.python.keras import backend as K |
22 | 23 |
|
23 | 24 | from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms as transforms_mod
|
24 | 25 |
|
@@ -302,15 +303,41 @@ def _add_replacement_layer(layer_node):
|
302 | 303 |
|
303 | 304 | _add_replacement_layer(replacement_layer_node)
|
304 | 305 |
|
| 306 | + @staticmethod |
| 307 | + def _weight_name(name): |
| 308 | + """Extracts the weight name by removing layer from TF variable name. |
| 309 | +
|
| 310 | + For example, returns 'kernel:0' for 'dense_2/kernel:0'. |
| 311 | +
|
| 312 | + Args: |
| 313 | + name: TensorFlow variable name. |
| 314 | +
|
| 315 | + Returns: |
| 316 | + Extracted weight name. |
| 317 | + """ |
| 318 | + return name.split('/')[-1] |
| 319 | + |
305 | 320 | def _get_keras_layer_weights(self, keras_layer):
|
306 | 321 | """Returns a map of weight name, weight matrix. Keeps keras ordering."""
|
307 | 322 | weights_map = collections.OrderedDict()
|
308 | 323 | for weight_tensor, weight_numpy in \
|
309 | 324 | zip(keras_layer.weights, keras_layer.get_weights()):
|
310 |
| - weights_map[weight_tensor.name] = weight_numpy |
| 325 | + weights_map[self._weight_name(weight_tensor.name)] = weight_numpy |
311 | 326 |
|
312 | 327 | return weights_map
|
313 | 328 |
|
| 329 | + def _set_layer_weights(self, layer, weights_map): |
| 330 | + """Sets the values of weights in a Keras layer.""" |
| 331 | + |
| 332 | + weight_value_tuples = [] |
| 333 | + for weight_tensor in layer.weights: |
| 334 | + weight_name = self._weight_name(weight_tensor.name) |
| 335 | + if weight_name in weights_map: |
| 336 | + weight_value_tuples.append( |
| 337 | + (weight_tensor, weights_map[weight_name])) |
| 338 | + |
| 339 | + K.batch_set_value(weight_value_tuples) |
| 340 | + |
314 | 341 | def transform(self):
|
315 | 342 | """Transforms the Keras model by applying all the specified transforms.
|
316 | 343 |
|
@@ -390,7 +417,7 @@ def transform(self):
|
390 | 417 | for layer in transformed_model.layers:
|
391 | 418 | weights = self._layer_weights_map.get(layer.name)
|
392 | 419 | if weights:
|
393 |
| - layer.set_weights(list(weights.values())) |
| 420 | + self._set_layer_weights(layer, weights) |
394 | 421 |
|
395 | 422 | # TODO(pulkitb): Consider returning the updated metadata for the
|
396 | 423 | # transformed model along with the model. This allows the opportunity for
|
|
0 commit comments