@@ -68,14 +68,35 @@ def _is_functional_model(model):
68
68
and not isinstance (model , keras .Sequential ) \
69
69
and model ._is_graph_network # pylint: disable=protected-access
70
70
71
+ def _inbound_node_generator (self , layer ):
72
+ for inbound_node in layer ['inbound_nodes' ]:
73
+ if len (inbound_node ) > 0 and isinstance (inbound_node [0 ], str ):
74
+ # TODO(tfmot): The case for the SlicingOpLambda.
75
+ yield [inbound_node ]
76
+ else :
77
+ yield inbound_node
78
+
79
+ def _get_inbound_layer_names (self , layer ):
80
+ """Return all the inbound connection layer names for the layer."""
81
+ inbound_layer_names = []
82
+ for inbound_node in self ._inbound_node_generator (layer ):
83
+ for connection_info in inbound_node :
84
+ # input argument case.
85
+ inbound_layer_names .append (connection_info [0 ])
86
+ # **kwarg argument case.
87
+ inbound_layer_names += [
88
+ value [0 ] for value in connection_info [3 ].items ()]
89
+
90
+ return inbound_layer_names
91
+
71
92
def _get_consuming_layers (self , check_layer ):
72
93
"""Returns all the layers which are out nodes from the layer."""
73
94
consuming_layers = []
95
+ check_layer_name = check_layer ['config' ]['name' ]
74
96
for layer in self ._config ['layers' ]:
75
- for inbound_node in layer ['inbound_nodes' ]:
76
- for connection_info in inbound_node :
77
- if connection_info [0 ] == check_layer ['config' ]['name' ]:
78
- consuming_layers .append (layer )
97
+ if check_layer_name in self ._get_inbound_layer_names (layer ):
98
+ consuming_layers .append (layer )
99
+
79
100
return consuming_layers
80
101
81
102
def _get_output_consumers (self , check_layer ):
@@ -292,11 +313,22 @@ def _replace_functional(self, match_layer_node, replacement_layer_node):
292
313
# replaced layer should equal the original layer.
293
314
294
315
consuming_layers = self ._get_consuming_layers (match_layer_node .layer )
316
+ match_name = match_layer_node .layer ['config' ]['name' ]
317
+ replacement_name = replacement_layer_node .layer ['config' ]['name' ]
318
+
319
+ def _replace_layer_name_for_connection_info (
320
+ connection_info , match_name , replacement_name ):
321
+ if connection_info [0 ] == match_name :
322
+ connection_info [0 ] = replacement_name
323
+ for key in connection_info [3 ]:
324
+ if connection_info [3 ][key ][0 ] == match_name :
325
+ connection_info [3 ][key ][0 ] = replacement_name
326
+
295
327
for consumer in consuming_layers :
296
- for inbound_node in consumer [ 'inbound_nodes' ] :
328
+ for inbound_node in self . _inbound_node_generator ( consumer ) :
297
329
for connection_info in inbound_node :
298
- if connection_info [ 0 ] == match_layer_node . layer [ 'config' ][ 'name' ]:
299
- connection_info [ 0 ] = replacement_layer_node . layer [ 'config' ][ 'name' ]
330
+ _replace_layer_name_for_connection_info (
331
+ connection_info , match_name , replacement_name )
300
332
301
333
output_consumers = self ._get_output_consumers (match_layer_node .layer )
302
334
for output_consumer in output_consumers :
0 commit comments