Commit e307282

Added TLX Computation Graph

1 parent 226606c · commit e307282

Large commits have some content hidden by default; only a subset of the changed files is shown below.

43 files changed: 526 additions, 75 deletions

examples/basic_tutorials/test_topology.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 2 additions & 1 deletion
@@ -64,9 +64,10 @@ def __init__(self, name=None, act=None, *args, **kwargs):

         # Layer building state
         self._built = False
+
         # Layer nodes state
-        self._nodes = []
         self._nodes_fixed = False
+        self._build_graph = False

         # Layer weight state
         self._all_weights = []

tensorlayerx/nn/core/core_paddle.py

Lines changed: 1 addition & 1 deletion
@@ -65,8 +65,8 @@ def __init__(self, name=None, act=None, *args, **kwargs):
         self._paddle_built = False

         # Layer nodes state
-        self._nodes = []
         self._nodes_fixed = False
+        self._build_graph = False

         # Layer weight state
         self._all_weights = None

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 79 additions & 15 deletions
@@ -6,6 +6,7 @@
 from collections import OrderedDict, abc as container_abcs
 from collections import OrderedDict
 import time
+from queue import Queue
 import tensorlayerx as tlx
 import tensorflow as tf
 from tensorlayerx.nn.layers.utils import (get_variable_with_initializer, random_normal)
@@ -100,8 +101,8 @@ def __init__(self, name=None, act=None, *args, **kwargs):
         self._built = False

         # Layer nodes state
-        self._nodes = []
         self._nodes_fixed = False
+        self._build_graph = False

         # Layer weight state
         self._all_weights = None
@@ -583,8 +584,17 @@ def str_to_init(self, initializer):
         return str2init(initializer)

     def node_build(self, *inputs, **kwargs):
-        self.forward(*inputs, **kwargs)
-        return _global_layer_node
+        # Add nodes only when the composition is needed.
+        layers = self.layers_and_names(name_prefix='')
+        for layer_name, layer in layers:
+            if isinstance(layer, Module):
+                layer._build_graph = True
+
+        outputs = self.forward(*inputs, **kwargs)
+        self.inputs = inputs
+        self.outputs = outputs
+        self._node_by_depth, self._all_layers = self._construct_graph()
+        return self._node_by_depth, self._all_layers

     def _add_node(self, input_tensors, output_tensors):
         """Add a ModuleNode for this layer given input_tensors, output_tensors.
@@ -602,33 +612,87 @@ def _add_node(self, input_tensors, output_tensors):
         """
         inputs_list = tolist(input_tensors)
         outputs_list = tolist(output_tensors)
-
         if self.__class__.__name__ in tlx.layers.inputs.__all__:
             # for InputLayer, there should be no in_nodes
             in_nodes = []
             in_tensor_idxes = [0]
         else:
-            in_nodes = [tensor for tensor in inputs_list]
-            in_tensor_idxes = [idx for idx, tensor in enumerate(inputs_list)]
-            # in_nodes = [tensor._info[0] for tensor in inputs_list]
-            # in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
+            in_nodes = [tensor._info[0] for tensor in inputs_list]
+            in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
         node_index = len(_global_layer_node)

         new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes)
         _global_layer_node.append(new_node)
         for idx, tensor in enumerate(outputs_list):
             tensor._info = (new_node, idx)

+    def _construct_graph(self):
+        """construct computation graph for model using ModuleNode object"""
+        all_layers = []
+        node_by_depth = []
+
+        input_tensors_list = self.inputs if isinstance(self.inputs, list) else [self.inputs]
+
+        queue_node = Queue()
+        # BFS to visit all nodes that should be involved in the computation graph
+        output_tensors_list = self.outputs if isinstance(self.outputs, list) else [self.outputs]
+        output_nodes = [tensor._info[0] for tensor in output_tensors_list]
+
+        visited_node_names = set()
+        for out_node in output_nodes:
+            if out_node.visited:
+                continue
+            queue_node.put(out_node)
+
+            while not queue_node.empty():
+                cur_node = queue_node.get()
+                in_nodes = cur_node.in_nodes
+
+                for node in in_nodes:
+                    node.out_nodes.append(cur_node)
+                    if not node.visited:
+                        queue_node.put(node)
+                        node.visited = True
+                        if node.node_name not in visited_node_names:
+                            visited_node_names.add(node.node_name)
+                        # else have multiple layers with the same name
+                        else:
+                            raise ValueError(
+                                'Layer name \'%s\' has already been used by another layer. Please change the layer name.'
+                                % node.layer.name
+                            )
+
+        # construct the computation graph in top-sort order
+        cur_depth = [tensor[0]._info[0] for tensor in input_tensors_list]
+        next_depth = []
+        indegrees = {}
+
+        visited_layer_names = []
+        while not len(cur_depth) == 0:
+            node_by_depth.append(cur_depth)
+            for node in cur_depth:
+                if node.layer.name not in visited_layer_names:
+                    all_layers.append(node.layer)
+                    visited_layer_names.append(node.layer.name)
+                for out_node in node.out_nodes:
+                    if out_node.node_name not in indegrees.keys():
+                        indegrees[out_node.node_name] = len(out_node.in_nodes)
+                    indegrees[out_node.node_name] -= 1
+                    if indegrees[out_node.node_name] == 0:
+                        next_depth.append(out_node)
+
+            cur_depth = next_depth
+            next_depth = []
+
+        return node_by_depth, all_layers
+

 class ModuleNode(object):
     """
     The class :class:`ModuleNode` class represents a conceptional node for a layer.

-    ModuleNode is used for building static model and it is actually a light weighted
-    wrapper over Layer. Specifically, it is used for building static computational graph
-    (see _construct_graph() in tl.models.Model). In static model, each layer relates to
-    one or more ModuleNode, and the connection relationship between layers is built upon
-    ModuleNode. In addition, ModuleNode eases layer reuse and weights sharing.
+    ModuleNode is used for building topology and it is actually a light weighted
+    wrapper over Layer.

     Parameters
     ----------
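_add_node threads a (node, index) pair through each output tensor's _info attribute, and _construct_graph then works in two phases: a reverse breadth-first search from the output nodes that discovers every node feeding the outputs and fills in out_nodes, followed by a Kahn-style topological sort that peels off all zero-indegree nodes as one depth level per step. The same algorithm on a plain dictionary graph, as a standalone, hypothetical example independent of tensorlayerx:

from collections import deque

in_nodes = {            # node -> list of predecessor nodes
    'input': [],
    'dense1': ['input'],
    'dense2': ['input'],
    'concat': ['dense1', 'dense2'],
    'output': ['concat'],
}

# Phase 1: reverse BFS from the output to build out-edges for reachable nodes.
out_nodes = {name: [] for name in in_nodes}
visited, queue = set(), deque(['output'])
while queue:
    cur = queue.popleft()
    for pred in in_nodes[cur]:
        out_nodes[pred].append(cur)
        if pred not in visited:
            visited.add(pred)
            queue.append(pred)

# Phase 2: Kahn's algorithm, emitting one "depth" (all zero-indegree nodes) per step.
indegree = {name: len(preds) for name, preds in in_nodes.items()}
node_by_depth, cur_depth = [], ['input']
while cur_depth:
    node_by_depth.append(cur_depth)
    next_depth = []
    for node in cur_depth:
        for succ in out_nodes[node]:
            indegree[succ] -= 1
            if indegree[succ] == 0:
                next_depth.append(succ)
    cur_depth = next_depth

print(node_by_depth)  # [['input'], ['dense1', 'dense2'], ['concat'], ['output']]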
@@ -660,14 +724,14 @@ def __init__(self, layer, node_index, in_nodes, in_tensors, out_tensors, in_tens
         self.out_nodes = []
         self.in_tensors = in_tensors
         self.out_tensors = out_tensors
-        self.name = layer.name + "_node_{}".format(node_index)
+        self.node_name = layer.name + "_node_{}".format(node_index)

         self.in_tensors_idxes = in_tensor_idxes
         self.visited = False

     def __call__(self, inputs, **kwargs):
         """(1) Forwarding through the layer. (2) Update its input/output tensors."""
-        outputs = self.layer.forward(inputs, **kwargs)
+        outputs = self.layer(inputs, **kwargs)
         self.in_tensors = tolist(inputs)
         self.out_tensors = tolist(outputs)
         return self.out_tensors
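The renamed node_name keeps the per-call node identity separate from layer.name, and the _info bookkeeping is what lets a downstream layer find its predecessor nodes without any global lookup. A stripped-down mock of that mechanism, purely illustrative; FakeTensor and the function names are hypothetical and independent of tensorlayerx:

class FakeTensor:
    pass

class Node:
    def __init__(self, name, in_nodes):
        self.name, self.in_nodes, self.out_nodes = name, in_nodes, []

_global_layer_node = []

def add_node(layer_name, input_tensors, output_tensors):
    if not input_tensors:                       # input layer: no predecessors
        in_nodes = []
    else:
        in_nodes = [t._info[0] for t in input_tensors]
    node = Node("{}_node_{}".format(layer_name, len(_global_layer_node)), in_nodes)
    _global_layer_node.append(node)
    for idx, t in enumerate(output_tensors):
        t._info = (node, idx)                   # thread the node through the tensor
    return node

x = FakeTensor()
add_node('input', [], [x])                      # InputLayer registers the first node
y = FakeTensor()
add_node('linear_1', [x], [y])                  # next layer finds its predecessor via x._info
print([n.name for n in y._info[0].in_nodes])    # ['input_node_0']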

tensorlayerx/nn/core/core_torch.py

Lines changed: 1 addition & 1 deletion
@@ -61,8 +61,8 @@ def __init__(self, name=None, act=None, *args, **kwargs):
         self._built = False

         # Layer nodes state
-        self._nodes = []
         self._nodes_fixed = False
+        self._build_graph = False

         # Layer weight state
         self._all_weights = None

tensorlayerx/nn/layers/Transformer.py

Lines changed: 24 additions & 0 deletions
@@ -175,6 +175,9 @@ def forward(self, q, k=None, v=None, attn_mask=None, key_padding_mask=None):

         attn_output, attn_output_weights = self.multiheadattention(q, k, v, attn_mask, key_padding_mask)

+        if not self._nodes_fixed and self._build_graph:
+            self._add_node([q, k, v, attn_mask, key_padding_mask], [attn_output, attn_output_weights])
+            self._nodes_fixed = True
         return attn_output, attn_output_weights

@@ -308,6 +311,9 @@ def forward(
             memory_key_padding_mask=memory_key_padding_mask
         )

+        if not self._nodes_fixed and self._build_graph:
+            self._add_node([src, tgt, src_mask, tgt_mask, memory_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask], output)
+            self._nodes_fixed = True
         return output

     def generate_square_subsequent_mask(self, length):
@@ -389,6 +395,9 @@ def forward(self, src, mask=None, src_key_padding_mask=None):
         if self.norm is not None:
             output = self.norm(output)

+        if not self._nodes_fixed and self._build_graph:
+            self._add_node([src, mask, src_key_padding_mask], output)
+            self._nodes_fixed = True
         return output

@@ -461,6 +470,9 @@ def forward(
         if self.norm is not None:
             output = self.norm(output)

+        if not self._nodes_fixed and self._build_graph:
+            self._add_node([tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask], output)
+            self._nodes_fixed = True
         return output

@@ -549,13 +561,19 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
             the mask for the src keys per batch.

         """
+
+        inputs = [src, src_mask, src_key_padding_mask]
+
         src1 = self.self_attn(src, src, src, src_mask, src_key_padding_mask)[0]
         src = src + self.dropout1(src1)
         src = self.norm1(src)
         src1 = self.linear2(self.dropout2(self.act(self.linear1(src))))
         src = src + self.dropout3(src1)
         src = self.norm2(src)

+        if not self._nodes_fixed and self._build_graph:
+            self._add_node(inputs, src)
+            self._nodes_fixed = True
         return src

@@ -650,6 +668,8 @@ def forward(
             the mask for the memory keys per batch.

         """
+        inputs = [tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask]
+
         tgt1 = self.self_attn(tgt, tgt, tgt, tgt_mask, tgt_key_padding_mask)[0]
         tgt = tgt + self.dropout1(tgt1)
         tgt = self.norm1(tgt)
@@ -659,5 +679,9 @@ def forward(
         tgt1 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
         tgt = tgt + self.dropout3(tgt1)
         tgt = self.norm3(tgt)
+
+        if not self._nodes_fixed and self._build_graph:
+            self._add_node(inputs, tgt)
+            self._nodes_fixed = True
         return tgt
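Every forward method above gets the same three-line guard, and the encoder/decoder layers capture inputs before the residual updates overwrite src and tgt, so _add_node records the original tensors. The guard in isolation, as a stripped-down mock rather than code from the commit:

class MockLayer:
    def __init__(self, name):
        self.name = name
        self._build_graph = False
        self._nodes_fixed = False
        self.recorded = []

    def _add_node(self, inputs, outputs):
        self.recorded.append((inputs, outputs))

    def forward(self, x):
        out = x * 2                      # stands in for the real computation
        if not self._nodes_fixed and self._build_graph:
            self._add_node([x], out)     # register this call in the topology
            self._nodes_fixed = True     # never record this layer twice
        return out

layer = MockLayer('demo')
layer.forward(1)                         # ordinary call: nothing recorded
layer._build_graph = True                # what node_build does for every sub-layer
layer.forward(2)                         # first call during graph build: recorded
layer.forward(3)                         # _nodes_fixed now blocks duplicates
print(layer.recorded)                    # [([2], 4)]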
