
Commit c131090

init topology
1 parent c92a638 commit c131090


4 files changed, +131 -18 lines changed

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import tensorlayerx as tlx
+from tensorlayerx.nn import Module
+from tensorlayerx.nn import Linear, Dropout
+
+
+class CustomModel(Module):
+
+    def __init__(self):
+        super(CustomModel, self).__init__()
+        self.linear1 = Linear(out_features=800, act=tlx.ReLU, in_features=784)
+        self.linear2 = Linear(out_features=800, act=tlx.ReLU, in_features=800)
+        self.linear3 = Linear(out_features=10, act=tlx.ReLU, in_features=800)
+
+    def forward(self, x, foo=None):
+        z = self.linear1(x)
+        z = self.linear2(z)
+        out = self.linear3(z)
+        # if foo is not None:
+        #     out = tlx.relu(out)
+        return out
+
+model = CustomModel()
+
+layer_node = model.node_build(tlx.nn.Input(shape=(3, 784)))
+for node in layer_node:
+    print(node.node_index)
+    print(node.name)
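
As a side note, here is a minimal sketch of how the recorded topology could be inspected beyond the two print calls above. It assumes the script has already run, and relies only on the ModuleNode attributes introduced in core_tensorflow.py later in this commit (node_index, name, layer, in_tensors, out_tensors):

for node in layer_node:
    # one ModuleNode per recorded layer call; node.layer is the layer instance
    print(node.node_index, node.name, type(node.layer).__name__)
    print("    inputs:", len(node.in_tensors), "outputs:", len(node.out_tensors))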

tensorlayerx/nn/core/common.py

Lines changed: 9 additions & 0 deletions
@@ -437,3 +437,12 @@ def weight_reshape(weight, reshape=False):
     if len(weight.shape) == 5:
         weight = np.moveaxis(weight, (3, 4), (1, 0))
     return weight
+
+def tolist(tensors):
+    if isinstance(tensors, list) or isinstance(tensors, tuple):
+        ntensors = list()
+        for t in tensors:
+            ntensors += tolist(t)
+        return ntensors
+    else:
+        return [tensors]
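
The new tolist helper flattens arbitrarily nested lists/tuples and wraps a single object into a one-element list, which is how _add_node in core_tensorflow.py normalizes its tensor arguments. A quick usage sketch, assuming the import path follows from the file location above:

from tensorlayerx.nn.core.common import tolist  # path assumed from the file above

print(tolist(1))                 # [1]
print(tolist([1, (2, 3), [4]]))  # [1, 2, 3, 4]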

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 92 additions & 18 deletions
@@ -1,7 +1,7 @@
 #! /usr/bin/python
 # -*- coding: utf-8 -*-

-from .common import str2act, str2init
+from .common import str2act, str2init, tolist
 from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
 from collections import OrderedDict
 import time
@@ -12,6 +12,7 @@
 __all__ = ['Module', 'Sequential', 'ModuleList']

 _global_layer_name_dict = {}
+_global_layer_node = []


 class Module(object):
@@ -322,23 +323,6 @@ def insert_param_to_layer(self, param_name, param, check_name=True):
         except:
             pass

-    def _add_node(self, input_tensors, output_tensors):
-        """Add a LayerNode for this layer given input_tensors, output_tensors.
-
-        WARINING: This function should not be called from outside, it should only be called
-        in layer.__call__ when building static model.
-
-        Parameters
-        ----------
-        input_tensors : Tensor or a list of tensors
-            Input tensors to this layer.
-        output_tensors : Tensor or a list of tensors
-            Output tensors to this layer.
-
-        """
-
-        raise NotImplementedError
-
     @property
     def create_time(self):
         return self._create_time
@@ -597,6 +581,96 @@ def init_build(self, *inputs, **kwargs):
     def str_to_init(self, initializer):
         return str2init(initializer)

+    def node_build(self, *inputs, **kwargs):
+        self.forward(*inputs, **kwargs)
+        return _global_layer_node
+
+    def _add_node(self, input_tensors, output_tensors):
+        """Add a ModuleNode for this layer given input_tensors, output_tensors.
+
+        This function should not be called from outside; it should only be called
+        in __call__ when building a static model.
+
+        Parameters
+        ----------
+        input_tensors : Tensor or a list of tensors
+            Input tensors to this layer.
+        output_tensors : Tensor or a list of tensors
+            Output tensors to this layer.
+
+        """
+        inputs_list = tolist(input_tensors)
+        outputs_list = tolist(output_tensors)
+
+        if self.__class__.__name__ in tlx.layers.inputs.__all__:
+            # for InputLayer, there should be no in_nodes
+            in_nodes = []
+            in_tensor_idxes = [0]
+        else:
+            in_nodes = [tensor for tensor in inputs_list]
+            in_tensor_idxes = [idx for idx, tensor in enumerate(inputs_list)]
+            # in_nodes = [tensor._info[0] for tensor in inputs_list]
+            # in_tensor_idxes = [tensor._info[1] for tensor in inputs_list]
+        node_index = len(_global_layer_node)
+
+        new_node = ModuleNode(self, node_index, in_nodes, inputs_list, outputs_list, in_tensor_idxes)
+        _global_layer_node.append(new_node)
+        for idx, tensor in enumerate(outputs_list):
+            tensor._info = (new_node, idx)
+
+
+class ModuleNode(object):
+    """
+    The :class:`ModuleNode` class represents a conceptual node for a layer.
+
+    ModuleNode is used for building a static model and is actually a lightweight
+    wrapper over Layer. Specifically, it is used for building the static computational graph
+    (see _construct_graph() in tl.models.Model). In a static model, each layer relates to
+    one or more ModuleNodes, and the connection relationship between layers is built upon
+    ModuleNode. In addition, ModuleNode eases layer reuse and weights sharing.
+
+    Parameters
+    ----------
+    layer : tl.layers.Layer
+        A tl layer that wants to create a node.
+    node_index : int
+        Index of this node in layer._nodes.
+    in_nodes : a list of ModuleNode
+        Parent nodes to this node.
+    in_tensors : a list of tensors
+        Input tensors to this node.
+    out_tensors : a list of tensors
+        Output tensors to this node.
+    in_tensor_idxes : a list of int
+        Indexes of each input tensor in its corresponding node's out_tensors.
+
+    Methods
+    ---------
+    __init__()
+        Initializing the ModuleNode.
+    __call__()
+        (1) Forwarding through the layer. (2) Update its input/output tensors.
+    """
+
+    def __init__(self, layer, node_index, in_nodes, in_tensors, out_tensors, in_tensor_idxes):
+        self.layer = layer
+        self.node_index = node_index
+        self.in_nodes = in_nodes
+        self.out_nodes = []
+        self.in_tensors = in_tensors
+        self.out_tensors = out_tensors
+        self.name = layer.name + "_node_{}".format(node_index)
+
+        self.in_tensors_idxes = in_tensor_idxes
+        self.visited = False
+
+    def __call__(self, inputs, **kwargs):
+        """(1) Forwarding through the layer. (2) Update its input/output tensors."""
+        outputs = self.layer.forward(inputs, **kwargs)
+        self.in_tensors = tolist(inputs)
+        self.out_tensors = tolist(outputs)
+        return self.out_tensors
+

 class Sequential(Module):
     """

tensorlayerx/nn/layers/linear/base_linear.py

Lines changed: 3 additions & 0 deletions
@@ -114,5 +114,8 @@ def forward(self, inputs):
             z = self.bias_add(z, self.b)
         if self.act_init_flag:
             z = self.act(z)
+
+        if not self._nodes_fixed:
+            self._add_node(inputs, z)
         return z
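
For comparison, here is how the same two-line hook might look in any other layer's forward() so that node_build() can record it. IdentityReLU is a hypothetical example layer, not something added by this commit; it only assumes the _nodes_fixed flag and _add_node() shown in the diffs above and the tlx.relu call used in the example script:

import tensorlayerx as tlx
from tensorlayerx.nn import Module


class IdentityReLU(Module):
    # Hypothetical layer, used only to illustrate the hook pattern from base_linear.py.

    def forward(self, inputs):
        z = tlx.relu(inputs)
        # record a ModuleNode for this call while the topology is still being built
        if not self._nodes_fixed:
            self._add_node(inputs, z)
        return z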
