Commit 4493775

Fix torch backend manual naming and add a check that layer names cannot be repeated.

1 parent b3960f7 · commit 4493775

24 files changed: +198 −137 lines

tensorlayerx/backend/ops/tensorflow_nn.py
Lines changed: 36 additions & 27 deletions

@@ -40,6 +40,37 @@ def padding_format(padding):
         raise Exception("Unsupported padding: " + str(padding))
     return padding
 
+def channel_format(data_format, dim='2d'):
+    if dim == '1d':
+        if data_format in ["channels_last", "NWC", 'NLC']:
+            data_format = "NWC"
+        elif data_format in ["channels_first", "NCW", 'NCL']:
+            data_format = "NCW"
+        elif data_format == None:
+            data_format = None
+        else:
+            raise Exception("Unsupported data format: " + str(data_format))
+    elif dim == '2d':
+        if data_format in ["channels_last", "NHWC"]:
+            data_format = "NHWC"
+        elif data_format in ["channels_first", "NCHW"]:
+            data_format = "NCHW"
+        elif data_format == None:
+            data_format = None
+        else:
+            raise Exception("Unsupported data format: " + str(data_format))
+    elif dim == '3d':
+        if data_format in ['channels_last', 'NDHWC']:
+            data_format = 'NDHWC'
+        elif data_format in ['channels_first', 'NCDHW']:
+            data_format = 'NCDHW'
+        elif data_format == None:
+            data_format = None
+        else:
+            raise Exception("Unsupported data format: " + str(data_format))
+    else:
+        raise Exception("dim must be '1d', '2d', '3d'.")
+    return data_format
 
 
 def preprocess_padding(padding, dim='2d', data_format='NHWC'):
     # When explicit padding is used and data_format is "NHWC",
@@ -88,8 +119,6 @@ def check_padding(padding, dim='2d'):
         raise RuntimeError("expected padding to be a single integer value or a list of 3 values to match the convolution dimensions.")
 
 
-
-
 def preprocess_1d_format(data_format, padding):
     """
     Checks that the 1-D dataformat format correspond format.
@@ -105,14 +134,7 @@ def preprocess_1d_format(data_format, padding):
     -------
         str "NWC" or "NCW" and "SAME" or "VALID"
     """
-    if data_format in ["channels_last", "NWC", 'NLC']:
-        data_format = "NWC"
-    elif data_format in ["channels_first", "NCW", 'NCL']:
-        data_format = "NCW"
-    elif data_format == None:
-        data_format = None
-    else:
-        raise Exception("Unsupported data format: " + str(data_format))
+    data_format = channel_format(data_format, dim='1d')
     padding = padding_format(padding)
     return data_format, padding
 
@@ -133,14 +155,7 @@ def preprocess_2d_format(data_format, padding):
         str "NHWC" or "NCHW" and "SAME" or "VALID"
     """
 
-    if data_format in ["channels_last", "NHWC"]:
-        data_format = "NHWC"
-    elif data_format in ["channels_first", "NCHW"]:
-        data_format = "NCHW"
-    elif data_format == None:
-        data_format = None
-    else:
-        raise Exception("Unsupported data format: " + str(data_format))
+    data_format = channel_format(data_format, dim='2d')
     padding = padding_format(padding)
     return data_format, padding
 
@@ -161,14 +176,7 @@ def preprocess_3d_format(data_format, padding):
         str "NDHWC" or "NCDHW" and "SAME" or "VALID"
     """
 
-    if data_format in ['channels_last', 'NDHWC']:
-        data_format = 'NDHWC'
-    elif data_format in ['channels_first', 'NCDHW']:
-        data_format = 'NCDHW'
-    elif data_format == None:
-        data_format = None
-    else:
-        raise Exception("Unsupported data format: " + str(data_format))
+    data_format = channel_format(data_format, dim='3d')
     padding = padding_format(padding)
     return data_format, padding
 
@@ -868,10 +876,11 @@ def __init__(self, ksize, strides, padding, data_format=None):
             self.padding = "VALID"
 
     def __call__(self, inputs):
+        data_format = channel_format(self.data_format, str(len(inputs.shape) - 2) + 'd')
         if self.padding_value is not None:
             inputs = tf.pad(inputs, self.padding_value)
         outputs = tf.nn.avg_pool(
-            input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
+            input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=data_format
         )
         return outputs
 
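Note: a minimal usage sketch of the new helper (assuming it is importable from tensorlayerx.backend.ops.tensorflow_nn; the values shown just follow the branches in the diff above):

    from tensorlayerx.backend.ops.tensorflow_nn import channel_format

    channel_format("channels_last", dim="2d")   # -> "NHWC"
    channel_format("NCL", dim="1d")             # -> "NCW"
    channel_format(None, dim="3d")              # -> None (passed through)
    # channel_format("NHWC", dim="1d")          # raises "Unsupported data format: NHWC"

    # In AvgPool.__call__ above, dim is derived from the input rank:
    # a 4-D NHWC tensor gives str(len(inputs.shape) - 2) + 'd' == '2d'.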

tensorlayerx/nn/core/core_tensorflow.py
Lines changed: 16 additions & 1 deletion

@@ -102,6 +102,9 @@ def __init__(self, name=None, act=None, *args, **kwargs):
         # Layer training state
         self.is_train = True
 
+        # weights check state
+        self._check = False
+
     def extend_repr(self):
         """
         Sets the extended representation of the Module.
@@ -160,11 +163,23 @@ def __setattr__(self, name, value):
             object.__setattr__(self, name, value)
 
     def __call__(self, inputs, *args, **kwargs):
+        if self._check == False:
+            self.train_weights_check()
+            self._check = True
 
         output = self.forward(inputs, *args, **kwargs)
-
         return output
 
+    def train_weights_check(self):
+        _param_name = []
+        for w in self.trainable_weights:
+            if w.name not in _param_name:
+                _param_name.append(w.name)
+            else:
+                raise Exception("parameter name [{}] have be been used. "
+                                "In training, the name of layer can't be same."
+                                "Please check the layers name".format(w.name))
+
     def forward(self, *inputs, **kwargs):
         raise Exception("The forward method must be implemented by inherited class")
 
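Note: a hedged sketch of what the new check catches (the layer choice and API usage here are illustrative, not part of this commit). Two layers that share a name produce trainable weights with identical names, so the first forward pass now raises:

    import numpy as np
    import tensorlayerx as tlx
    from tensorlayerx.nn import Module, Linear

    class Net(Module):
        def __init__(self):
            super(Net, self).__init__()
            # deliberately reuse the name "dense" for both layers
            self.fc1 = Linear(out_features=16, in_features=8, name="dense")
            self.fc2 = Linear(out_features=4, in_features=16, name="dense")

        def forward(self, x):
            return self.fc2(self.fc1(x))

    net = Net()
    # The first __call__ runs train_weights_check(), which walks
    # trainable_weights and raises on the duplicated weight names.
    out = net(tlx.convert_to_tensor(np.ones((2, 8), dtype="float32")))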

tensorlayerx/nn/core/core_torch.py
Lines changed: 36 additions & 3 deletions

@@ -5,7 +5,7 @@
 from .common import check_parameter, processing_act, str2init, tolist, construct_graph, ModuleNode, select_attrs
 from .common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
 from torch.nn.parameter import Parameter
-from torch._C import _disabled_torch_function_impl
+from typing import Any, Callable
 import torch
 import operator
 from itertools import islice
@@ -67,6 +67,10 @@ def __init__(self, name=None, act=None, *args, **kwargs):
         # layer forward state
         self._forward_state = False
 
+        # weights check state
+        self._check = False
+
+
     def set_train(self, mode=True):
         if not isinstance(mode, bool):
             raise ValueError("training mode is expected to be boolean")
@@ -85,8 +89,6 @@ def forward(self, *inputs, **kwargs):
         raise Exception("The forward method must be implemented by inherited class")
 
     def _get_weights(self, var_name, shape, init=None, trainable=True, transposed=None, order=False):
-        var_name = self.name + "/" + var_name
-
         if order:
             w_tmp = Parameter(init(shape), requires_grad=trainable)
             return w_tmp
@@ -103,8 +105,39 @@ def _get_weights(self, var_name, shape, init=None, trainable=True, transposed=None, order=False):
         # TODO paramters name should be add
         _param = init(shape)
         param = Parameter(_param, requires_grad=trainable)
+        self.var_name = var_name
         return param
 
+    def _call_impl_tlx(self, *input, **kwargs):
+        if self._check == False:
+            _param_name = []
+            for name, param in self.named_parameters(recurse=True):
+                if name not in _param_name:
+                    _param_name.append(name)
+                else:
+                    raise Exception("parameter name [{}] have be been used. "
+                                    "In training, the name of layer can't be same."
+                                    "Please check the layers name".format(name))
+            self._check = True
+
+        result = self._call_impl(*input, **kwargs)
+        return result
+
+    __call__: Callable[..., Any] = _call_impl_tlx
+
+    def _named_members(self, get_members_fn, prefix='', recurse=True):
+        r"""Helper method for yielding various names + members of modules."""
+        memo = set()
+        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
+        for module_prefix, module in modules:
+            members = get_members_fn(module)
+            for k, v in members:
+                if v is None or v in memo:
+                    continue
+                memo.add(v)
+                name = module.name + '/' + k
+                yield name, v
+
     @property
     def all_weights(self):
         if self._all_weights is not None and len(self._all_weights) > 0:
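Note: with the manual "self.name + '/'" prefix removed from _get_weights, parameter names under the torch backend now come from the _named_members override: module.name + '/' + parameter key. A standalone sketch of that naming scheme (simplified; the real code iterates named_modules()):

    def tlx_param_names(modules):
        # modules: list of (layer_name, {param_key: param}) pairs
        memo, names = set(), []
        for layer_name, params in modules:
            for k, v in params.items():
                if id(v) in memo:      # skip parameters already yielded
                    continue
                memo.add(id(v))
                names.append(layer_name + '/' + k)
        return names

    tlx_param_names([("conv1", {"filters": object()}),
                     ("dense", {"weights": object(), "biases": object()})])
    # -> ['conv1/filters', 'dense/weights', 'dense/biases']
    # Duplicate layer names would yield duplicate entries here, which
    # _call_impl_tlx rejects on the first forward call.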

tensorlayerx/nn/layers/Transformer.py
Lines changed: 6 additions & 6 deletions

@@ -103,16 +103,16 @@ def __repr__(self):
     def build(self, inputs_shape):
         bias_init = tlx.nn.initializers.zeros()
         weight_init = tlx.nn.initializers.XavierNormal()
-        self.q_proj_weight = self._get_weights(
+        self.q_weight = self._get_weights(
             'q_weight', shape=(self.embed_dim, self.embed_dim), init=weight_init, order=True
         )
-        self.k_proj_weight = self._get_weights(
+        self.k_weight = self._get_weights(
             'k_weight', shape=(self.embed_dim, self.kdim), init=weight_init, order=True
         )
-        self.v_proj_weight = self._get_weights(
+        self.v_weight = self._get_weights(
             'v_weight', shape=(self.embed_dim, self.vdim), init=weight_init, order=True
        )
-        self.out_proj_weight = self._get_weights(
+        self.out_weight = self._get_weights(
             'out_weight', shape=(self.embed_dim, self.embed_dim), init=weight_init, order=True
         )
         self.q_bias = None
@@ -127,8 +127,8 @@ def build(self, inputs_shape):
 
         self.multiheadattention = tlx.ops.multiheadattention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, batch_first=self.batch_first,
-            need_weights=self.need_weights, q_weight=self.q_proj_weight, k_weight=self.k_proj_weight,
-            v_weight=self.v_proj_weight, out_weight=self.out_proj_weight, q_bias=self.q_bias, k_bias=self.k_bias,
+            need_weights=self.need_weights, q_weight=self.q_weight, k_weight=self.k_weight,
+            v_weight=self.v_weight, out_weight=self.out_weight, q_bias=self.q_bias, k_bias=self.k_bias,
             v_bias=self.v_bias, out_bias=self.out_bias, train=self.is_train
         )
 
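Note on the renames here and in the layers below: since torch-backend parameter names are now derived from attribute names (via named_parameters and the _named_members override above), each weight attribute is renamed to match the var_name string passed to _get_weights, e.g.:

    self.q_weight = self._get_weights('q_weight', ...)
    # The torch backend reports this parameter as "<layer name>/q_weight",
    # consistent with the name the other backends construct explicitly.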

tensorlayerx/nn/layers/activation.py
Lines changed: 4 additions & 4 deletions

@@ -73,7 +73,7 @@ def build(self, inputs_shape):
         elif dim < 3:
             w_shape = (self.num_parameters, )
 
-        self.alpha_var = self._get_weights("alpha", shape=w_shape, init=tlx.initializers.constant(value=self.init))
+        self.alpha = self._get_weights("alpha", shape=w_shape, init=tlx.initializers.constant(value=self.init))
         self.prelu = tlx.ops.PReLU(data_format = self.data_format)
 
     def forward(self, inputs):
@@ -83,7 +83,7 @@ def forward(self, inputs):
             self._built = True
             self._forward_state = True
 
-        output = self.prelu(inputs, self.alpha_var)
+        output = self.prelu(inputs, self.alpha)
 
         if not self._nodes_fixed and self._build_graph:
             self._add_node(inputs, output)
@@ -186,7 +186,7 @@ def build(self, inputs_shape):
             w_shape = (1, self.in_channels, 1, 1, 1)
         else:
             raise Exception("Dim should be equal to 1, 2 or 3")
-        self.alpha_var = self._get_weights("alpha", shape=w_shape, init=self.a_init)
+        self.alpha = self._get_weights("alpha", shape=w_shape, init=self.a_init)
         self.sigmoid = tlx.ops.Sigmoid()
         self.relu = tlx.ops.ReLU()
 
@@ -197,7 +197,7 @@ def forward(self, inputs):
         self._built = True
         self._forward_state = True
 
-        alpha_var_constrained = self.sigmoid(self.alpha_var)
+        alpha_var_constrained = self.sigmoid(self.alpha)
         pos = self.relu(inputs)
         pos_6 = -self.relu(inputs - 6)
         neg = -alpha_var_constrained * self.relu(-inputs)

tensorlayerx/nn/layers/convolution/binary_conv.py
Lines changed: 4 additions & 4 deletions

@@ -114,11 +114,11 @@ def build(self, inputs_shape):
 
         self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, self.out_channels)
 
-        self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
+        self.filters = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
 
         self.b_init_flag = False
         if self.b_init:
-            self.b = self._get_weights("biases", shape=(self.out_channels, ), init=self.b_init)
+            self.biases = self._get_weights("biases", shape=(self.out_channels, ), init=self.b_init)
             self.bias_add = tlx.ops.BiasAdd(self.data_format)
             self.b_init_flag = True
 
@@ -143,10 +143,10 @@ def forward(self, inputs):
             self._built = True
             self._forward_state = True
 
-        outputs = self.binaryconv2d(inputs, self.W)
+        outputs = self.binaryconv2d(inputs, self.filters)
 
         if self.b_init_flag:
-            outputs = self.bias_add(outputs, self.b)
+            outputs = self.bias_add(outputs, self.biases)
         if self.act_init_flag:
             outputs = self.act(outputs)
 

tensorlayerx/nn/layers/convolution/deformable_conv.py
Lines changed: 4 additions & 4 deletions

@@ -146,10 +146,10 @@ def build(self, inputs_shape):
 
         self.filter_shape = (1, 1, self.kernel_n, self.in_channels, self.out_channels)
 
-        self.W = self._get_weights("W_deformableconv2d", shape=self.filter_shape, init=self.W_init)
+        self.W_deformableconv2d = self._get_weights("W_deformableconv2d", shape=self.filter_shape, init=self.W_init)
 
         if self.b_init:
-            self.b = self._get_weights("b_deformableconv2d", shape=(self.out_channels, ), init=self.b_init)
+            self.b_deformableconv2d = self._get_weights("b_deformableconv2d", shape=(self.out_channels, ), init=self.b_init)
 
         self.conv3d = tlx.ops.Conv3D(strides=[1, 1, 1, 1, 1], padding='VALID')
         self.bias_add = tlx.ops.BiasAdd()
@@ -166,12 +166,12 @@ def forward(self, inputs):
         grid_offset = self.grid_offset
 
         input_deform = self._tf_batch_map_offsets(inputs, offset, grid_offset)
-        outputs = self.conv3d(input=input_deform, filters=self.W)
+        outputs = self.conv3d(input=input_deform, filters=self.W_deformableconv2d)
         outputs = tlx.ops.reshape(
             tensor=outputs, shape=[outputs.get_shape()[0], self.input_h, self.input_w, self.out_channels]
         )
         if self.b_init:
-            outputs = self.bias_add(outputs, self.b)
+            outputs = self.bias_add(outputs, self.b_deformableconv2d)
         if self.act:
             outputs = self.act(outputs)
         return outputs

tensorlayerx/nn/layers/convolution/depthwise_conv.py
Lines changed: 5 additions & 5 deletions

@@ -137,15 +137,15 @@ def build(self, inputs_shape):
         self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, 1)
 
         if BACKEND in ['tensorflow', 'mindspore']:
-            self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init, transposed=True)
-            self.point_W = None
+            self.filters = self._get_weights("filters", shape=self.filter_shape, init=self.W_init, transposed=True)
+            self.point_filter = None
         # TODO The number of parameters on multiple backends is not equal.
         # TODO It might be better to use deepwise convolution and pointwise convolution for other backends as well.
         if BACKEND in ['paddle', 'torch']:
             self.filter_depthwise = (self.in_channels, 1, self.kernel_size[0], self.kernel_size[1])
             self.filter_pointwise = (self.in_channels * self.depth_multiplier, self.in_channels, 1, 1)
-            self.W = self._get_weights("filters", shape=self.filter_depthwise, init=self.W_init, order=True)
-            self.point_W = self._get_weights("point_filter", shape=self.filter_pointwise, init=self.W_init, order=True)
+            self.filters = self._get_weights("filters", shape=self.filter_depthwise, init=self.W_init, order=True)
+            self.point_filter = self._get_weights("point_filter", shape=self.filter_pointwise, init=self.W_init, order=True)
 
         self.depthwise_conv2d = tlx.ops.DepthwiseConv2d(
             strides=self._strides, padding=self.padding, data_format=self.data_format, dilations=self._dilation,
@@ -169,7 +169,7 @@ def forward(self, inputs):
             self._built = True
             self._forward_state = True
 
-        outputs = self.depthwise_conv2d(input=inputs, filter=self.W, point_filter=self.point_W)
+        outputs = self.depthwise_conv2d(input=inputs, filter=self.filters, point_filter=self.point_filter)
         if self.b_init_flag:
             outputs = self.bias_add(outputs, self.b)
         if self.act_init_flag:
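Note: the parameter-count TODO above can be made concrete (illustrative arithmetic only, using the shapes from this diff). tensorflow/mindspore store one depthwise filter of shape (kh, kw, in_channels, 1), while paddle/torch store a depthwise filter (in_channels, 1, kh, kw) plus a pointwise filter (in_channels * depth_multiplier, in_channels, 1, 1):

    kh = kw = 3
    in_channels, depth_multiplier = 64, 1
    tf_count = kh * kw * in_channels * 1                            # 576
    torch_count = (in_channels * 1 * kh * kw                        # 576 depthwise
                   + in_channels * depth_multiplier * in_channels)  # 4096 pointwise
    print(tf_count, torch_count)   # 576 4672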

tensorlayerx/nn/layers/convolution/dorefa_conv.py
Lines changed: 4 additions & 4 deletions

@@ -133,11 +133,11 @@ def build(self, inputs_shape):
 
         self.filter_shape = (self.kernel_size[0], self.kernel_size[1], self.in_channels, self.out_channels)
 
-        self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
+        self.filters = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
 
         self.b_init_flag = False
         if self.b_init:
-            self.b = self._get_weights("biases", shape=(self.out_channels, ), init=self.b_init)
+            self.biases = self._get_weights("biases", shape=(self.out_channels, ), init=self.b_init)
             self.bias_add = tlx.ops.BiasAdd(self.data_format)
             self.b_init_flag = True
 
@@ -159,10 +159,10 @@ def forward(self, inputs):
             self._built = True
             self._forward_state = True
 
-        outputs = self.dorefaconv2d(inputs, self.W)
+        outputs = self.dorefaconv2d(inputs, self.filters)
 
         if self.b_init_flag:
-            outputs = self.bias_add(outputs, self.b)
+            outputs = self.bias_add(outputs, self.biases)
         if self.act_init_flag:
             outputs = self.act(outputs)
 
