
Commit 5ff8686

Fix processing activation
1 parent 7593500 commit 5ff8686

File tree

6 files changed (+94, -58 lines)

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+#! /usr/bin/python
+# -*- coding: utf-8 -*-
+
+import tensorlayerx as tlx
+from tensorlayerx.nn import Module
+from tensorlayerx.nn import Linear, Conv2d, BatchNorm2d, MaxPool2d, Flatten
+
+class CNN(Module):
+
+    def __init__(self):
+        super(CNN, self).__init__()
+        # weights init
+        W_init = tlx.nn.initializers.truncated_normal(stddev=5e-2)
+        W_init2 = tlx.nn.initializers.truncated_normal(stddev=0.04)
+        b_init2 = tlx.nn.initializers.constant(value=0.1)
+
+        self.conv1 = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv1', in_channels=3, act=tlx.nn.ReLU)
+        self.bn = BatchNorm2d(num_features=64, act=tlx.nn.ReLU)
+        self.maxpool1 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')
+
+        self.conv2 = Conv2d(
+            64, (5, 5), (1, 1), padding='SAME', act=tlx.nn.ReLU, W_init=W_init, b_init=None, name='conv2', in_channels=64
+        )
+        self.maxpool2 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')
+
+        self.flatten = Flatten(name='flatten')
+        self.linear1 = Linear(384, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear1relu', in_features=2304)
+        self.linear2 = Linear(192, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear2relu', in_features=384)
+        self.linear3 = Linear(10, act=None, W_init=W_init2, name='output1', in_features=192)
+        self.linear4 = Linear(20, act=None, W_init=W_init2, name='output2', in_features=192)
+        self.concat = tlx.nn.Concat(name='concat')
+
+    def forward(self, x):
+        z = self.conv1(x)
+        z = self.bn(z)
+        z = self.maxpool1(z)
+        z = self.conv2(z)
+        z = self.maxpool2(z)
+        z = self.flatten(z)
+        z = self.linear1(z)
+        z = self.linear2(z)
+        z1 = self.linear3(z)
+        z2 = self.linear4(z)
+        z = self.concat([z1, z2])
+        return z
+
+model = CNN()
+inputs = tlx.nn.Input(shape=(3, 24, 24, 3))
+outputs = model(inputs)
+
+node_by_depth, all_layers = model.build_graph(inputs)
+
+for depth, nodes in enumerate(node_by_depth):
+    if depth == 0:
+        if isinstance(inputs, list):
+            assert len(inputs) == len(nodes)
+            for idx, node in enumerate(nodes):
+                print(node.node_name, node.layer)
+        else:
+            print(nodes[0].node_name, nodes[0].layer)
+    else:
+        for node in nodes:
+            print(node.node_name, node.layer)
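
A quick output-shape check (a sketch, not part of the commit; it assumes Concat joins the two heads along the last axis, which appears to be its default):

print(outputs.shape)  # expected (3, 30): 10 logits from 'output1' plus 20 from 'output2'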

tensorlayerx/nn/core/common.py

Lines changed: 22 additions & 0 deletions
@@ -87,6 +87,28 @@ def str2act(act):
     return _act_dict[act]


+def processing_act(act):
+    # String as input: activation function without parameters.
+    if isinstance(act, str):
+        str_act = str2act(act)
+    if act:
+        # String as input: activation function with parameters.
+        if isinstance(act, str) and (len(act) > 5 and act[0:5] == "lrelu" or
+                                     len(act) > 10 and act[0:10] == "leaky_relu"):
+            out_act = str_act
+        elif isinstance(act, str):
+            out_act = str_act()
+        # Class or function as input: activation function without parameters.
+        elif type(act) == type(tlx.nn.ReLU):
+            out_act = act()
+        # Class or function as input: activation function with parameters.
+        else:
+            out_act = act
+    else:
+        out_act = act
+    return out_act
+
+
 def _save_weights(net, file_path, format=None):
     """Input file_path, save model weights into a file of given format.
     Use net.load_weights() to restore.
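
For orientation, a minimal usage sketch of the new helper (not part of the commit; the activation names and the parametric string form are assumptions about str2act and tlx.nn):

import tensorlayerx as tlx
from tensorlayerx.nn.core.common import processing_act

print(processing_act(None))                # falsy input is passed through unchanged
print(processing_act('relu'))              # plain string: class looked up via str2act, then instantiated
print(processing_act(tlx.nn.ReLU))         # activation class: instantiated
print(processing_act(tlx.nn.LeakyReLU()))  # already-built instance: returned unchanged
# print(processing_act('leaky_relu0.2'))   # parametric string form (exact syntax depends on str2act)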

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 2 additions & 14 deletions
@@ -1,7 +1,7 @@
 #! /usr/bin/python
 # -*- coding: utf-8 -*-

-from .common import check_parameter, str2act, str2init, random_normal, tolist, construct_graph, ModuleNode, select_attrs
+from .common import check_parameter, processing_act, str2init, random_normal, tolist, construct_graph, ModuleNode, select_attrs
 from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
 from mindspore.nn import Cell
 import tensorlayerx as tlx
@@ -49,19 +49,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):

         self.name = name

-        if isinstance(act, str):
-            str_act = str2act(act)
-
-        if act:
-            if isinstance(act, str) and (len(act) > 5 and act[0:5] == "lrelu" or
-                                         len(act) > 10 and act[0:10] == "leaky_relu"):
-                self.act = str_act
-            elif isinstance(act, str):
-                self.act = str_act()
-            else:
-                self.act = act()
-        else:
-            self.act = act
+        self.act = processing_act(act)

         # Layer building state
         self._built = False

tensorlayerx/nn/core/core_paddle.py

Lines changed: 2 additions & 14 deletions
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 import copy, six
-from .common import check_parameter, str2act, str2init, tolist, construct_graph, ModuleNode
+from .common import check_parameter, processing_act, str2init
 from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
 from paddle.fluid import framework
 from paddle.fluid.dygraph import Layer
@@ -49,19 +49,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):

         self.name = name

-        if isinstance(act, str):
-            str_act = str2act(act)
-
-        if act:
-            if isinstance(act, str) and (len(act) > 5 and act[0:5] == "lrelu" or
-                                         len(act) > 10 and act[0:10] == "leaky_relu"):
-                self.act = str_act
-            elif isinstance(act, str):
-                self.act = str_act()
-            else:
-                self.act = act()
-        else:
-            self.act = act
+        self.act = processing_act(act)

         # Layer building state
         self._built = False

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 3 additions & 16 deletions
@@ -1,7 +1,7 @@
 #! /usr/bin/python
 # -*- coding: utf-8 -*-

-from .common import check_parameter, str2act, str2init, tolist, construct_graph, ModuleNode, select_attrs
+from .common import check_parameter, processing_act, str2init, tolist, construct_graph, ModuleNode, select_attrs
 from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
 from collections import OrderedDict, abc as container_abcs
 from collections import OrderedDict
@@ -82,19 +82,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):

         self.name = name

-        if isinstance(act, str):
-            str_act = str2act(act)
-
-        if act:
-            if isinstance(act, str) and (len(act) > 5 and act[0:5] == "lrelu" or
-                                         len(act) > 10 and act[0:10] == "leaky_relu"):
-                self.act = str_act
-            elif isinstance(act, str):
-                self.act = str_act()
-            else:
-                self.act = act()
-        else:
-            self.act = act
+        self.act = processing_act(act)

         # Layer building state
         self._built = False
@@ -587,8 +575,7 @@ def check_param(self, param, dim='2d'):

     def build_graph(self, *inputs, **kwargs):
         # Add nodes only when the composition is needed.
-        layers = self.layers_and_names(name_prefix='')
-        for layer_name, layer in layers:
+        for layer_name, layer in self._layers.items():
             if isinstance(layer, Module):
                 layer._build_graph = True
         self.set_eval()

tensorlayerx/nn/core/core_torch.py

Lines changed: 2 additions & 14 deletions
@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-

 from torch.nn import Module as T_Module
-from .common import check_parameter, str2act, str2init, tolist, construct_graph, ModuleNode, select_attrs
+from .common import check_parameter, processing_act, str2init, tolist, construct_graph, ModuleNode, select_attrs
 from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
 from torch.nn.parameter import Parameter
 from collections import OrderedDict
@@ -46,19 +46,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):

         self.name = name

-        if isinstance(act, str):
-            str_act = str2act(act)
-
-        if act:
-            if isinstance(act, str) and (len(act) > 5 and act[0:5] == "lrelu" or
-                                         len(act) > 10 and act[0:10] == "leaky_relu"):
-                self.act = str_act
-            elif isinstance(act, str):
-                self.act = str_act()
-            else:
-                self.act = act()
-        else:
-            self.act = act
+        self.act = processing_act(act)

         # Layer building state
         self._built = False
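
The practical effect of routing every backend through processing_act is that the same act argument forms resolve consistently everywhere (a sketch, not part of the commit; the Linear keyword names follow the example file above, and tlx.nn.LeakyReLU is an assumption):

import tensorlayerx as tlx
from tensorlayerx.nn import Linear

l1 = Linear(out_features=8, in_features=4, act='relu')              # string name
l2 = Linear(out_features=8, in_features=4, act=tlx.nn.ReLU)         # activation class
l3 = Linear(out_features=8, in_features=4, act=tlx.nn.LeakyReLU())  # activation instance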
