@@ -68,14 +68,35 @@ def _is_functional_model(model):
6868 and not isinstance (model , keras .Sequential ) \
6969 and model ._is_graph_network # pylint: disable=protected-access
7070
71+ def _inbound_node_generator (self , layer ):
72+ for inbound_node in layer ['inbound_nodes' ]:
73+ if len (inbound_node ) > 0 and isinstance (inbound_node [0 ], str ):
74+ # TODO(tfmot): The case for the SlicingOpLambda.
75+ yield [inbound_node ]
76+ else :
77+ yield inbound_node
78+
79+ def _get_inbound_layer_names (self , layer ):
80+ """Return all the inbound connection layer names for the layer."""
81+ inbound_layer_names = []
82+ for inbound_node in self ._inbound_node_generator (layer ):
83+ for connection_info in inbound_node :
84+ # input argument case.
85+ inbound_layer_names .append (connection_info [0 ])
86+ # **kwarg argument case.
87+ inbound_layer_names += [
88+ value [0 ] for value in connection_info [3 ].items ()]
89+
90+ return inbound_layer_names
91+
7192 def _get_consuming_layers (self , check_layer ):
7293 """Returns all the layers which are out nodes from the layer."""
7394 consuming_layers = []
95+ check_layer_name = check_layer ['config' ]['name' ]
7496 for layer in self ._config ['layers' ]:
75- for inbound_node in layer ['inbound_nodes' ]:
76- for connection_info in inbound_node :
77- if connection_info [0 ] == check_layer ['config' ]['name' ]:
78- consuming_layers .append (layer )
97+ if check_layer_name in self ._get_inbound_layer_names (layer ):
98+ consuming_layers .append (layer )
99+
79100 return consuming_layers
80101
81102 def _get_output_consumers (self , check_layer ):
@@ -292,11 +313,22 @@ def _replace_functional(self, match_layer_node, replacement_layer_node):
292313 # replaced layer should equal the original layer.
293314
294315 consuming_layers = self ._get_consuming_layers (match_layer_node .layer )
316+ match_name = match_layer_node .layer ['config' ]['name' ]
317+ replacement_name = replacement_layer_node .layer ['config' ]['name' ]
318+
319+ def _replace_layer_name_for_connection_info (
320+ connection_info , match_name , replacement_name ):
321+ if connection_info [0 ] == match_name :
322+ connection_info [0 ] = replacement_name
323+ for key in connection_info [3 ]:
324+ if connection_info [3 ][key ][0 ] == match_name :
325+ connection_info [3 ][key ][0 ] = replacement_name
326+
295327 for consumer in consuming_layers :
296- for inbound_node in consumer [ 'inbound_nodes' ] :
328+ for inbound_node in self . _inbound_node_generator ( consumer ) :
297329 for connection_info in inbound_node :
298- if connection_info [ 0 ] == match_layer_node . layer [ 'config' ][ 'name' ]:
299- connection_info [ 0 ] = replacement_layer_node . layer [ 'config' ][ 'name' ]
330+ _replace_layer_name_for_connection_info (
331+ connection_info , match_name , replacement_name )
300332
301333 output_consumers = self ._get_output_consumers (match_layer_node .layer )
302334 for output_consumer in output_consumers :
0 commit comments