@@ -95,6 +95,9 @@ def _get_layers(self, layer_names):
9595 def _get_layer_weights (self , layer_name ):
9696 return self ._layer_weights_map .get (layer_name , {})
9797
98+ def _get_layer_names_and_weights (self , layer_name ):
99+ return self ._layer_names_and_weights_map .get (layer_name , {})
100+
98101 def _get_layer_metadata (self , layer_name ):
99102 return self ._layer_metadata_map .get (layer_name , {})
100103
@@ -191,7 +194,9 @@ def _match_layer_with_inputs(self, layer, pattern, is_head_node):
191194 if len (pattern .inputs ) == 0 :
192195 # Leaf layer in pattern.
193196 return LayerNode (layer , self ._get_layer_weights (layer ['config' ]['name' ]),
194- [], self ._get_layer_metadata (layer ['config' ]['name' ]))
197+ [], self ._get_layer_metadata (layer ['config' ]['name' ]),
198+ self ._get_layer_names_and_weights (
199+ layer ['config' ]['name' ]))
195200
196201 # There is a possible edge case where a single layer may output multiple
197202 # tensors and multiple tensors from that layer may be used by the
@@ -221,7 +226,8 @@ def _match_layer_with_inputs(self, layer, pattern, is_head_node):
221226
222227 return LayerNode (layer , self ._get_layer_weights (layer ['config' ]['name' ]),
223228 input_match_layer_nodes ,
224- self ._get_layer_metadata (layer ['config' ]['name' ]))
229+ self ._get_layer_metadata (layer ['config' ]['name' ]),
230+ self ._get_layer_names_and_weights (layer ['config' ]['name' ]))
225231
226232 def _find_pattern (self , pattern , matched_layers = None ):
227233 for layer in self ._config ['layers' ]:
@@ -265,6 +271,7 @@ def _remove_layers(self, layers_to_remove, layers_to_remove_names):
265271 # now that layer has been removed.
266272 for layer_name in layers_to_remove_names :
267273 self ._layer_weights_map .pop (layer_name , None )
274+ self ._layer_names_and_weights_map .pop (layer_name , None )
268275 self ._layer_metadata_map .pop (layer_name , None )
269276
270277 def _replace (self , match_layer_node , replacement_layer_node ):
@@ -355,8 +362,12 @@ def _add_replacement_layer(layer_node):
355362 """Recursively add new layers."""
356363 self ._config ['layers' ].append (layer_node .layer )
357364 layer_name = layer_node .layer ['config' ]['name' ]
365+ # TODO(b/184603494): Remove weight map structure from model_transformer.
358366 if layer_node .weights :
359367 self ._layer_weights_map [layer_name ] = layer_node .weights
368+ if layer_node .names_and_weights :
369+ self ._layer_names_and_weights_map [
370+ layer_name ] = layer_node .names_and_weights
360371 if layer_node .metadata :
361372 self ._layer_metadata_map [layer_name ] = layer_node .metadata
362373 if self .candidate_layers :
@@ -403,6 +414,9 @@ def _add_replacement_nodes(first_layer_removed_index, replacement_nodes):
403414 layer_name = replacement_node .layer ['config' ]['name' ]
404415 if replacement_node .weights :
405416 self ._layer_weights_map [layer_name ] = replacement_node .weights
417+ if replacement_node .names_and_weights :
418+ self ._layer_names_and_weights_map [
419+ layer_name ] = replacement_node .names_and_weights
406420 if replacement_node .metadata :
407421 self ._layer_metadata_map [layer_name ] = replacement_node .metadata
408422 if self .candidate_layers :
@@ -433,8 +447,17 @@ def _get_keras_layer_weights(self, keras_layer):
433447 zip (keras_layer .weights , keras_layer .get_weights ()):
434448 weights_map [self ._weight_name (weight_tensor .name )] = weight_numpy
435449
450+ if len (weights_map ) != len (keras_layer .weights ):
451+ # The case that variable identifier is not unique. It's a fallback that
452+ # uses weight list instead of the weights map.
453+ return None
454+
436455 return weights_map
437456
457+ def _get_keras_layer_names_and_weights (self , keras_layer ):
458+ return zip ([weight .name for weight in keras_layer .weights ],
459+ keras_layer .get_weights ())
460+
438461 def _set_layer_weights (self , layer , weights_map ):
439462 """Sets the values of weights in a Keras layer."""
440463
@@ -447,6 +470,9 @@ def _set_layer_weights(self, layer, weights_map):
447470
448471 K .batch_set_value (weight_value_tuples )
449472
473+ def _set_layer_names_and_weights (self , layer , names_and_weights ):
474+ layer .set_weights ([weight for _ , weight in names_and_weights ])
475+
450476 @staticmethod
451477 def _name (obj ):
452478 return obj .__class__ .__name__
@@ -496,8 +522,12 @@ def transform(self):
496522 # to prevent infinite loops.
497523 self ._transform_matched_layers_map = {}
498524 self ._layer_weights_map = {}
525+ self ._layer_names_and_weights_map = {}
526+
499527 for layer in self .model .layers :
500528 self ._layer_weights_map [layer .name ] = self ._get_keras_layer_weights (layer )
529+ self ._layer_names_and_weights_map [
530+ layer .name ] = self ._get_keras_layer_names_and_weights (layer )
501531
502532 # Maintains a current mutable copy of the metadata through transformation.
503533 self ._layer_metadata_map = copy .deepcopy (self .layer_metadata )
@@ -557,5 +587,9 @@ def transform(self):
557587 weights = self ._layer_weights_map .get (layer .name )
558588 if weights :
559589 self ._set_layer_weights (layer , weights )
590+ else :
591+ names_and_weights = self ._layer_names_and_weights_map .get (layer .name )
592+ if names_and_weights :
593+ self ._set_layer_names_and_weights (layer , names_and_weights )
560594
561595 return transformed_model , copy .deepcopy (self ._layer_metadata_map )
0 commit comments