Skip to content

Commit 015ebdf

Browse files
committed
Fix activation and update Deconv2D pad mode
1 parent 2a730fb commit 015ebdf

File tree

5 files changed: +141 additions, −63 deletions

tensorlayerx/backend/ops/mindspore_nn.py

Lines changed: 119 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,66 +1183,132 @@ def __init__(self, strides, padding, data_format, dilations=None, out_channel=No
11831183
if self.data_format == 'NHWC':
11841184
raise NotImplementedError("The optional value for data format. Currently only support “NCWH”.")
11851185

1186-
self.conv2d_transpose = P.Conv2DBackpropInput(
1187-
out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.strides,
1188-
dilation=self.dilations, mode=1, group=1, data_format=self.data_format
1189-
)
1190-
self.shape = P.Shape()
1186+
if isinstance(self.padding, str):
1187+
self.pad_mode = self.padding
1188+
self.pad = 0
1189+
else:
1190+
self.pad_mode = 'pad'
1191+
if isinstance(self.padding, tuple):
1192+
self.padding = (self.padding[0], self.padding[0], self.padding[1], self.padding[1])
1193+
self.pad = self.padding
1194+
self.is_valid = self.pad_mode == 'valid'
1195+
self.is_same = self.pad_mode == 'same'
1196+
self.is_pad = self.pad_mode == 'pad'
1197+
# cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
1198+
self.conv2d_transpose = P.Conv2DTranspose(out_channel=in_channels,
1199+
kernel_size=self.k_size,
1200+
mode=1,
1201+
pad_mode=self.pad_mode,
1202+
pad=self.pad,
1203+
stride=self.strides,
1204+
dilation=self.dilations,
1205+
group=1)
1206+
if isinstance(self.padding, int):
1207+
self.padding_top, self.padding_bottom, self.padding_left, self.padding_right = (self.padding,) * 4
1208+
else:
1209+
self.padding_top, self.padding_bottom = (self.padding[0],) * 2
1210+
self.padding_left, self.padding_right = (self.padding[1],) * 2
11911211

1192-
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
1212+
def construct(self, x, filters):
1213+
n, _, h, w = P.Shape()(x)
1214+
h_out = self._deconv_output_length(self.is_valid, self.is_same, self.is_pad, h, self.k_size[0],
1215+
self.strides[0], self.dilations[0], self.padding_top + self.padding_bottom)
1216+
w_out = self._deconv_output_length(self.is_valid, self.is_same, self.is_pad, w, self.k_size[1],
1217+
self.strides[1], self.dilations[1], self.padding_left + self.padding_right)
1218+
output = self.conv2d_transpose(x, filters, (n, self.out_channel, h_out, w_out))
1219+
return output
1220+
1221+
def _deconv_output_length(self, is_valid, is_same, is_pad, input_length, filter_size, stride_size, dilation_size,
1222+
padding):
1223+
"""Calculate the width and height of output."""
11931224
length = 0
11941225
filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
1195-
1196-
if self.padding == 'same':
1226+
if is_valid:
1227+
if filter_size - stride_size > 0:
1228+
length = input_length * stride_size + filter_size - stride_size
1229+
else:
1230+
length = input_length * stride_size
1231+
elif is_same:
11971232
length = input_length * stride_size
1198-
elif self.padding == 'valid':
1199-
length = input_length * stride_size + max(filter_size - stride_size, 0)
1233+
elif is_pad:
1234+
length = input_length * stride_size - padding + filter_size - stride_size
12001235

12011236
return length
12021237

1203-
def construct(self, x, filters):
1204-
if self.data_format == 'NHWC':
1205-
h_axis, w_axis = 1, 2
1206-
n, h, w, _ = self.shape(x)
1207-
else:
1208-
h_axis, w_axis = 2, 3
1209-
n, _, h, w = self.shape(x)
1210-
1211-
if isinstance(self.strides, int):
1212-
strides_h = self.strides
1213-
strides_w = self.strides
1214-
else:
1215-
strides_list = list(self.strides)
1216-
if len(strides_list) == 2:
1217-
strides_h = strides_list[0]
1218-
strides_w = strides_list[1]
1219-
elif len(strides_list) == 4:
1220-
strides_h = strides_list[h_axis]
1221-
strides_w = strides_list[w_axis]
1222-
1223-
if self.dilations is not None:
1224-
if isinstance(self.dilations, int):
1225-
dilations_h = self.dilations
1226-
dilations_w = self.dilations
1227-
else:
1228-
dilations_list = list(self.dilations)
1229-
if len(dilations_list) == 2:
1230-
dilations_h = dilations_list[0]
1231-
dilations_w = dilations_list[1]
1232-
elif len(dilations_list) == 4:
1233-
dilations_h = dilations_list[h_axis]
1234-
dilations_w = dilations_list[w_axis]
1235-
1236-
h_out = self._deconv_output_length(h, self.k_size[0], strides_h, dilations_h)
1237-
w_out = self._deconv_output_length(w, self.k_size[1], strides_w, dilations_w)
1238-
1239-
if self.data_format == 'NCHW':
1240-
output_size = (n, self.out_channel, h_out, w_out)
1241-
else:
1242-
output_size = (n, h_out, w_out, self.out_channel)
1243-
output = self.conv2d_transpose(x, filters, output_size)
1244-
1245-
return output
1238+
# class Conv2d_transpose(Cell):
1239+
#
1240+
# def __init__(self, strides, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None):
1241+
# super(Conv2d_transpose, self).__init__()
1242+
# self.data_format, self.padding = preprocess_2d_format(data_format, padding)
1243+
# self.in_channels = in_channels
1244+
# self.out_channel = out_channel
1245+
# self.k_size = k_size
1246+
# self.strides = strides
1247+
# self.dilations = dilations
1248+
#
1249+
# if self.data_format == 'NHWC':
1250+
# raise NotImplementedError("The optional value for data format. Currently only support “NCWH”.")
1251+
#
1252+
# self.conv2d_transpose = P.Conv2DBackpropInput(
1253+
# out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.strides,
1254+
# dilation=self.dilations, mode=1, group=1, data_format=self.data_format
1255+
# )
1256+
# self.shape = P.Shape()
1257+
#
1258+
# def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
1259+
# length = 0
1260+
# filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
1261+
#
1262+
# if self.padding == 'same':
1263+
# length = input_length * stride_size
1264+
# elif self.padding == 'valid':
1265+
# length = input_length * stride_size + max(filter_size - stride_size, 0)
1266+
#
1267+
# return length
1268+
#
1269+
# def construct(self, x, filters):
1270+
# if self.data_format == 'NHWC':
1271+
# h_axis, w_axis = 1, 2
1272+
# n, h, w, _ = self.shape(x)
1273+
# else:
1274+
# h_axis, w_axis = 2, 3
1275+
# n, _, h, w = self.shape(x)
1276+
#
1277+
# if isinstance(self.strides, int):
1278+
# strides_h = self.strides
1279+
# strides_w = self.strides
1280+
# else:
1281+
# strides_list = list(self.strides)
1282+
# if len(strides_list) == 2:
1283+
# strides_h = strides_list[0]
1284+
# strides_w = strides_list[1]
1285+
# elif len(strides_list) == 4:
1286+
# strides_h = strides_list[h_axis]
1287+
# strides_w = strides_list[w_axis]
1288+
#
1289+
# if self.dilations is not None:
1290+
# if isinstance(self.dilations, int):
1291+
# dilations_h = self.dilations
1292+
# dilations_w = self.dilations
1293+
# else:
1294+
# dilations_list = list(self.dilations)
1295+
# if len(dilations_list) == 2:
1296+
# dilations_h = dilations_list[0]
1297+
# dilations_w = dilations_list[1]
1298+
# elif len(dilations_list) == 4:
1299+
# dilations_h = dilations_list[h_axis]
1300+
# dilations_w = dilations_list[w_axis]
1301+
#
1302+
# h_out = self._deconv_output_length(h, self.k_size[0], strides_h, dilations_h)
1303+
# w_out = self._deconv_output_length(w, self.k_size[1], strides_w, dilations_w)
1304+
#
1305+
# if self.data_format == 'NCHW':
1306+
# output_size = (n, self.out_channel, h_out, w_out)
1307+
# else:
1308+
# output_size = (n, h_out, w_out, self.out_channel)
1309+
# output = self.conv2d_transpose(x, filters, output_size)
1310+
#
1311+
# return output
12461312

12471313

12481314
def conv2d_transpose(

tensorlayerx/backend/ops/tensorflow_nn.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1310,13 +1310,23 @@ def __call__(self, input, filters):
13101310
kernel_h = kernel_h + (kernel_h - 1) * (dilations_h - 1)
13111311
kernel_w = kernel_w + (kernel_w - 1) * (dilations_w - 1)
13121312

1313-
assert self.padding in {'SAME', 'VALID'}
1313+
if tf.__version__ < '2.4.0' and not isinstance(self.padding, str):
1314+
assert self.padding in {'SAME', 'VALID'}
13141315
if self.padding == 'VALID':
13151316
output_h = input_h * strides_h + max(kernel_h - strides_h, 0)
13161317
output_w = input_w * strides_w + max(kernel_w - strides_w, 0)
13171318
elif self.padding == 'SAME':
13181319
output_h = input_h * strides_h
13191320
output_w = input_w * strides_w
1321+
else:
1322+
if isinstance(self.padding, int):
1323+
output_h = input_h * strides_h + max(kernel_h - strides_h, 0) - 2 * self.padding
1324+
output_w = input_w * strides_w + max(kernel_w - strides_w, 0) - 2 * self.padding
1325+
self.padding = [[0, 0], [self.padding, self.padding],[self.padding, self.padding], [0, 0]]
1326+
else:
1327+
output_h = input_h * strides_h + max(kernel_h - strides_h, 0) - 2 * self.padding[0]
1328+
output_w = input_w * strides_w + max(kernel_w - strides_w, 0) - 2* self.padding[1]
1329+
self.padding = [[0, 0], [self.padding[0], self.padding[0]],[self.padding[1], self.padding[1]], [0, 0]]
13201330

13211331
if self.data_format == 'NCHW':
13221332
out_shape = (batch_size, output_channels, output_h, output_w)

tensorlayerx/backend/ops/torch_nn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,7 @@ def __call__(self, input, filters):
12401240
out = F.conv_transpose1d(
12411241
input,
12421242
weight=filters,
1243-
padding=0,
1243+
padding=(0 if isinstance(self.padding, str) else self.padding),
12441244
stride=self.stride,
12451245
dilation=self.dilations
12461246
)
@@ -1319,7 +1319,7 @@ def __call__(self, input, filters):
13191319
out = F.conv_transpose2d(
13201320
input,
13211321
weight=filters,
1322-
padding=0,
1322+
padding=(0 if isinstance(self.padding, str) else self.padding),
13231323
stride=self.strides,
13241324
dilation=self.dilations
13251325
)
@@ -1399,7 +1399,7 @@ def __call__(self, input, filters):
13991399
out = F.conv_transpose3d(
14001400
input,
14011401
weight=filters,
1402-
padding=0,
1402+
padding=(0 if isinstance(self.padding, str) else self.padding),
14031403
stride=self.strides,
14041404
dilation=self.dilations
14051405
)

tensorlayerx/nn/core/common.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,15 @@ def processing_act(act):
9898
out_act = str_act
9999
elif isinstance(act, str):
100100
out_act = str_act()
101-
# Processing classes or functions as input, activation functions without parameters
102-
elif type(act) == type(tlx.nn.ReLU):
103-
out_act = act()
104-
# Processing class or function as input, activation function with parameters
105101
else:
106-
out_act = act
102+
# Processing classes or functions as input, activation functions without parameters
103+
try:
104+
out_act = act()
105+
# Processing class or function as input, activation function with parameters
106+
except:
107+
out_act = act
107108
else:
109+
# Processing act is None
108110
out_act = act
109111
return out_act
110112

tensorlayerx/nn/layers/convolution/simplified_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ class ConvTranspose2d(Module):
611611
Specifying the dilation rate to use for dilated convolution.
612612
act : activation function
613613
The activation function of this layer.
614-
padding : str
614+
padding : int, tuple or str
615615
The padding algorithm type: "SAME" or "VALID".
616616
data_format : str
617617
"channels_last" (NHWC, default) or "channels_first" (NCHW).

0 commit comments

Comments (0)