 from collections import OrderedDict, abc as container_abcs
 from collections import OrderedDict
 import time
+from queue import Queue
 import tensorlayerx as tlx
 import tensorflow as tf
 from tensorlayerx.nn.layers.utils import (get_variable_with_initializer, random_normal)
@@ -100,8 +101,8 @@ def __init__(self, name=None, act=None, *args, **kwargs):
         self._built = False

         # Layer nodes state
-        self._nodes = []
         self._nodes_fixed = False
+        self._build_graph = False

         # Layer weight state
         self._all_weights = None
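The new `_build_graph` flag lets `node_build` switch node recording on for every sublayer before running a forward pass, so topology is captured only when it is asked for. A minimal sketch of how such a flag can gate node creation in a layer's call path (hypothetical standalone class, not the actual tensorlayerx gating code):

```python
# Minimal sketch (not the actual tensorlayerx implementation): a layer
# records a node only while graph building is switched on via _build_graph.
class GateSketch:
    def __init__(self):
        self._build_graph = False  # node_build() flips this on every sublayer

    def forward(self, inputs):
        return inputs  # identity placeholder for the sketch

    def _add_node(self, inputs, outputs):
        print("recording node:", inputs, "->", outputs)

    def __call__(self, inputs):
        outputs = self.forward(inputs)
        if self._build_graph:
            self._add_node(inputs, outputs)  # topology recorded only on demand
        return outputs

layer = GateSketch()
layer(1.0)                 # plain forward, nothing recorded
layer._build_graph = True
layer(1.0)                 # same forward, but the node is recorded
```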
@@ -583,8 +584,17 @@ def str_to_init(self, initializer):
         return str2init(initializer)

     def node_build(self, *inputs, **kwargs):
-        self.forward(*inputs, **kwargs)
-        return _global_layer_node
+        # Add nodes only when the composition is needed.
+        layers = self.layers_and_names(name_prefix='')
+        for layer_name, layer in layers:
+            if isinstance(layer, Module):
+                layer._build_graph = True
+
+        outputs = self.forward(*inputs, **kwargs)
+        self.inputs = inputs
+        self.outputs = outputs
+        self._node_by_depth, self._all_layers = self._construct_graph()
+        return self._node_by_depth, self._all_layers

     def _add_node(self, input_tensors, output_tensors):
         """Add a ModuleNode for this layer given input_tensors, output_tensors.
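With this change, `node_build` marks every sublayer with `_build_graph`, runs a single forward pass so each layer records its `ModuleNode`, and returns the assembled topology instead of the raw `_global_layer_node` list. A hedged usage sketch; the `Input` and `Linear` constructors and their argument names are assumptions about the tlx API, not taken from this diff:

```python
import tensorlayerx as tlx
from tensorlayerx.nn import Module

class MLP(Module):
    def __init__(self):
        super(MLP, self).__init__()
        # layer class and argument names are assumed for illustration
        self.fc1 = tlx.nn.Linear(out_features=16, in_features=8)
        self.fc2 = tlx.nn.Linear(out_features=4, in_features=16)

    def forward(self, x):
        return self.fc2(self.fc1(x))

net = MLP()
inputs = tlx.nn.Input(shape=(2, 8))
node_by_depth, all_layers = net.node_build(inputs)
# node_by_depth[0] holds the input node(s); later entries follow topological order
```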
@@ -602,33 +612,87 @@ def _add_node(self, input_tensors, output_tensors):
         """
         inputs_list = tolist(input_tensors)
         outputs_list = tolist(output_tensors)
-
         if self.__class__.__name__ in tlx.layers.inputs.__all__:
             # for InputLayer, there should be no in_nodes
             in_nodes = []
             in_tensor_idxes = [0]
         else:
-            in_nodes = [tensor for tensor in inputs_list]
-            in_tensor_idxes = [idx for idx, tensor in enumerate(inputs_list)]
-            # in_nodes = [tensor._info[0] for tensor in inputs_list]
-            # in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
+            in_nodes = [tensor._info[0] for tensor in inputs_list]
+            in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
         node_index = len(_global_layer_node)

         new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes)
         _global_layer_node.append(new_node)
         for idx, tensor in enumerate(outputs_list):
             tensor._info = (new_node, idx)

+    def _construct_graph(self):
+        """Construct the computation graph for the model using ModuleNode objects."""
+        all_layers = []
+        node_by_depth = []
+
+        input_tensors_list = self.inputs if isinstance(self.inputs, list) else [self.inputs]
+
+        queue_node = Queue()
+        # BFS to visit all nodes that should be involved in the computation graph
+        output_tensors_list = self.outputs if isinstance(self.outputs, list) else [self.outputs]
+        output_nodes = [tensor._info[0] for tensor in output_tensors_list]
+
+        visited_node_names = set()
+        for out_node in output_nodes:
+            if out_node.visited:
+                continue
+            queue_node.put(out_node)
+
+        while not queue_node.empty():
+            cur_node = queue_node.get()
+            in_nodes = cur_node.in_nodes
+
+            for node in in_nodes:
+                node.out_nodes.append(cur_node)
+                if not node.visited:
+                    queue_node.put(node)
+                    node.visited = True
+                    if node.node_name not in visited_node_names:
+                        visited_node_names.add(node.node_name)
+                    # else: multiple layers share the same name
+                    else:
+                        raise ValueError(
+                            'Layer name \'%s\' has already been used by another layer. Please change the layer name.'
+                            % node.layer.name
+                        )
+
+        # construct the computation graph in top-sort order
+        cur_depth = [tensor[0]._info[0] for tensor in input_tensors_list]
+        next_depth = []
+        indegrees = {}
+
+        visited_layer_names = []
+        while not len(cur_depth) == 0:
+            node_by_depth.append(cur_depth)
+            for node in cur_depth:
+                if node.layer.name not in visited_layer_names:
+                    all_layers.append(node.layer)
+                    visited_layer_names.append(node.layer.name)
+                for out_node in node.out_nodes:
+                    if out_node.node_name not in indegrees.keys():
+                        indegrees[out_node.node_name] = len(out_node.in_nodes)
+                    indegrees[out_node.node_name] -= 1
+                    if indegrees[out_node.node_name] == 0:
+                        next_depth.append(out_node)
+
+            cur_depth = next_depth
+            next_depth = []
+
+        return node_by_depth, all_layers
+

 class ModuleNode(object):
     """
     The :class:`ModuleNode` class represents a conceptual node for a layer.

-    ModuleNode is used for building static model and it is actually a light weighted
-    wrapper over Layer. Specifically, it is used for building static computational graph
-    (see _construct_graph() in tl.models.Model). In static model, each layer relates to
-    one or more ModuleNode, and the connection relationship between layers is built upon
-    ModuleNode. In addition, ModuleNode eases layer reuse and weights sharing.
+    ModuleNode is used for building the network topology and is actually a lightweight
+    wrapper over Layer.

     Parameters
     ----------
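`_construct_graph` works in two phases: a breadth-first walk back from the output nodes that fills in each node's `out_nodes` (and rejects duplicate layer names), then a Kahn-style topological sort forward from the inputs, where a node is appended to the next depth once its indegree drops to zero. As a result, `node_by_depth[d]` holds the nodes whose longest path from an input has length d. The same second phase on a toy dict-based DAG, self-contained and independent of tlx:

```python
# Toy DAG: name -> consumers, and name -> producers
edges = {'input': ['fc1'], 'fc1': ['fc2', 'fc3'], 'fc2': ['out'], 'fc3': ['out'], 'out': []}
in_nodes = {'input': [], 'fc1': ['input'], 'fc2': ['fc1'], 'fc3': ['fc1'], 'out': ['fc2', 'fc3']}

node_by_depth, cur_depth, indegrees = [], ['input'], {}
while cur_depth:
    node_by_depth.append(cur_depth)
    next_depth = []
    for name in cur_depth:
        for out_name in edges[name]:
            # first visit initializes the indegree; each visit consumes one edge
            indegrees.setdefault(out_name, len(in_nodes[out_name]))
            indegrees[out_name] -= 1
            if indegrees[out_name] == 0:
                next_depth.append(out_name)
    cur_depth = next_depth

print(node_by_depth)  # [['input'], ['fc1'], ['fc2', 'fc3'], ['out']]
```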
@@ -660,14 +724,14 @@ def __init__(self, layer, node_index, in_nodes, in_tensors, out_tensors, in_tens
         self.out_nodes = []
         self.in_tensors = in_tensors
         self.out_tensors = out_tensors
-        self.name = layer.name + "_node_{}".format(node_index)
+        self.node_name = layer.name + "_node_{}".format(node_index)

         self.in_tensors_idxes = in_tensor_idxes
         self.visited = False

     def __call__(self, inputs, **kwargs):
         """(1) Forwarding through the layer. (2) Update its input/output tensors."""
-        outputs = self.layer.forward(inputs, **kwargs)
+        outputs = self.layer(inputs, **kwargs)
         self.in_tensors = tolist(inputs)
         self.out_tensors = tolist(outputs)
         return self.out_tensors
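Routing `ModuleNode.__call__` through `self.layer(...)` rather than `self.layer.forward(...)` means a replay of the graph goes through the layer's full call path. A hedged sketch of such a replay for a single-input graph, using a hypothetical helper built only on the node attributes defined above:

```python
def replay(node_by_depth, input_tensor):
    """Hypothetical helper: re-run a graph recorded by node_build, depth by depth."""
    memo = {}  # node_name -> list of that node's output tensors
    for depth, nodes in enumerate(node_by_depth):
        for node in nodes:
            if depth == 0:
                memo[node.node_name] = node(input_tensor)
            else:
                # fetch each input from the producing node's output slot
                args = [memo[prev.node_name][idx]
                        for prev, idx in zip(node.in_nodes, node.in_tensors_idxes)]
                memo[node.node_name] = node(args[0] if len(args) == 1 else args)
    return [memo[node.node_name] for node in node_by_depth[-1]]
```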