@@ -1183,66 +1183,132 @@ def __init__(self, strides, padding, data_format, dilations=None, out_channel=No
11831183 if self .data_format == 'NHWC' :
11841184 raise NotImplementedError ("The optional value for data format. Currently only support “NCWH”." )
11851185
1186- self .conv2d_transpose = P .Conv2DBackpropInput (
1187- out_channel = self .in_channels , kernel_size = self .k_size , pad_mode = self .padding , stride = self .strides ,
1188- dilation = self .dilations , mode = 1 , group = 1 , data_format = self .data_format
1189- )
1190- self .shape = P .Shape ()
1186+ if isinstance (self .padding , str ):
1187+ self .pad_mode = self .padding
1188+ self .pad = 0
1189+ else :
1190+ self .pad_mode = 'pad'
1191+ if isinstance (self .padding , tuple ):
1192+ self .padding = (self .padding [0 ], self .padding [0 ], self .padding [1 ], self .padding [1 ])
1193+ self .pad = self .padding
1194+ self .is_valid = self .pad_mode == 'valid'
1195+ self .is_same = self .pad_mode == 'same'
1196+ self .is_pad = self .pad_mode == 'pad'
1197+ # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
1198+ self .conv2d_transpose = P .Conv2DTranspose (out_channel = in_channels ,
1199+ kernel_size = self .k_size ,
1200+ mode = 1 ,
1201+ pad_mode = self .pad_mode ,
1202+ pad = self .pad ,
1203+ stride = self .strides ,
1204+ dilation = self .dilations ,
1205+ group = 1 )
1206+ if isinstance (self .padding , int ):
1207+ self .padding_top , self .padding_bottom , self .padding_left , self .padding_right = (self .padding ,) * 4
1208+ else :
1209+ self .padding_top , self .padding_bottom = (self .padding [0 ],) * 2
1210+ self .padding_left , self .padding_right = (self .padding [1 ],) * 2
11911211
1192- def _deconv_output_length (self , input_length , filter_size , stride_size , dilation_size ):
1212+ def construct (self , x , filters ):
1213+ n , _ , h , w = P .Shape ()(x )
1214+ h_out = self ._deconv_output_length (self .is_valid , self .is_same , self .is_pad , h , self .k_size [0 ],
1215+ self .strides [0 ], self .dilations [0 ], self .padding_top + self .padding_bottom )
1216+ w_out = self ._deconv_output_length (self .is_valid , self .is_same , self .is_pad , w , self .k_size [1 ],
1217+ self .strides [1 ], self .dilations [1 ], self .padding_left + self .padding_right )
1218+ output = self .conv2d_transpose (x , filters , (n , self .out_channel , h_out , w_out ))
1219+ return output
1220+
1221+ def _deconv_output_length (self , is_valid , is_same , is_pad , input_length , filter_size , stride_size , dilation_size ,
1222+ padding ):
1223+ """Calculate the width and height of output."""
11931224 length = 0
11941225 filter_size = filter_size + (filter_size - 1 ) * (dilation_size - 1 )
1195-
1196- if self .padding == 'same' :
1226+ if is_valid :
1227+ if filter_size - stride_size > 0 :
1228+ length = input_length * stride_size + filter_size - stride_size
1229+ else :
1230+ length = input_length * stride_size
1231+ elif is_same :
11971232 length = input_length * stride_size
1198- elif self . padding == 'valid' :
1199- length = input_length * stride_size + max ( filter_size - stride_size , 0 )
1233+ elif is_pad :
1234+ length = input_length * stride_size - padding + filter_size - stride_size
12001235
12011236 return length
12021237
1203- def construct (self , x , filters ):
1204- if self .data_format == 'NHWC' :
1205- h_axis , w_axis = 1 , 2
1206- n , h , w , _ = self .shape (x )
1207- else :
1208- h_axis , w_axis = 2 , 3
1209- n , _ , h , w = self .shape (x )
1210-
1211- if isinstance (self .strides , int ):
1212- strides_h = self .strides
1213- strides_w = self .strides
1214- else :
1215- strides_list = list (self .strides )
1216- if len (strides_list ) == 2 :
1217- strides_h = strides_list [0 ]
1218- strides_w = strides_list [1 ]
1219- elif len (strides_list ) == 4 :
1220- strides_h = strides_list [h_axis ]
1221- strides_w = strides_list [w_axis ]
1222-
1223- if self .dilations is not None :
1224- if isinstance (self .dilations , int ):
1225- dilations_h = self .dilations
1226- dilations_w = self .dilations
1227- else :
1228- dilations_list = list (self .dilations )
1229- if len (dilations_list ) == 2 :
1230- dilations_h = dilations_list [0 ]
1231- dilations_w = dilations_list [1 ]
1232- elif len (dilations_list ) == 4 :
1233- dilations_h = dilations_list [h_axis ]
1234- dilations_w = dilations_list [w_axis ]
1235-
1236- h_out = self ._deconv_output_length (h , self .k_size [0 ], strides_h , dilations_h )
1237- w_out = self ._deconv_output_length (w , self .k_size [1 ], strides_w , dilations_w )
1238-
1239- if self .data_format == 'NCHW' :
1240- output_size = (n , self .out_channel , h_out , w_out )
1241- else :
1242- output_size = (n , h_out , w_out , self .out_channel )
1243- output = self .conv2d_transpose (x , filters , output_size )
1244-
1245- return output
1238+ # class Conv2d_transpose(Cell):
1239+ #
1240+ # def __init__(self, strides, padding, data_format, dilations=None, out_channel=None, k_size=None, in_channels=None):
1241+ # super(Conv2d_transpose, self).__init__()
1242+ # self.data_format, self.padding = preprocess_2d_format(data_format, padding)
1243+ # self.in_channels = in_channels
1244+ # self.out_channel = out_channel
1245+ # self.k_size = k_size
1246+ # self.strides = strides
1247+ # self.dilations = dilations
1248+ #
1249+ # if self.data_format == 'NHWC':
1250+ # raise NotImplementedError("The optional value for data format. Currently only support “NCWH”.")
1251+ #
1252+ # self.conv2d_transpose = P.Conv2DBackpropInput(
1253+ # out_channel=self.in_channels, kernel_size=self.k_size, pad_mode=self.padding, stride=self.strides,
1254+ # dilation=self.dilations, mode=1, group=1, data_format=self.data_format
1255+ # )
1256+ # self.shape = P.Shape()
1257+ #
1258+ # def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
1259+ # length = 0
1260+ # filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
1261+ #
1262+ # if self.padding == 'same':
1263+ # length = input_length * stride_size
1264+ # elif self.padding == 'valid':
1265+ # length = input_length * stride_size + max(filter_size - stride_size, 0)
1266+ #
1267+ # return length
1268+ #
1269+ # def construct(self, x, filters):
1270+ # if self.data_format == 'NHWC':
1271+ # h_axis, w_axis = 1, 2
1272+ # n, h, w, _ = self.shape(x)
1273+ # else:
1274+ # h_axis, w_axis = 2, 3
1275+ # n, _, h, w = self.shape(x)
1276+ #
1277+ # if isinstance(self.strides, int):
1278+ # strides_h = self.strides
1279+ # strides_w = self.strides
1280+ # else:
1281+ # strides_list = list(self.strides)
1282+ # if len(strides_list) == 2:
1283+ # strides_h = strides_list[0]
1284+ # strides_w = strides_list[1]
1285+ # elif len(strides_list) == 4:
1286+ # strides_h = strides_list[h_axis]
1287+ # strides_w = strides_list[w_axis]
1288+ #
1289+ # if self.dilations is not None:
1290+ # if isinstance(self.dilations, int):
1291+ # dilations_h = self.dilations
1292+ # dilations_w = self.dilations
1293+ # else:
1294+ # dilations_list = list(self.dilations)
1295+ # if len(dilations_list) == 2:
1296+ # dilations_h = dilations_list[0]
1297+ # dilations_w = dilations_list[1]
1298+ # elif len(dilations_list) == 4:
1299+ # dilations_h = dilations_list[h_axis]
1300+ # dilations_w = dilations_list[w_axis]
1301+ #
1302+ # h_out = self._deconv_output_length(h, self.k_size[0], strides_h, dilations_h)
1303+ # w_out = self._deconv_output_length(w, self.k_size[1], strides_w, dilations_w)
1304+ #
1305+ # if self.data_format == 'NCHW':
1306+ # output_size = (n, self.out_channel, h_out, w_out)
1307+ # else:
1308+ # output_size = (n, h_out, w_out, self.out_channel)
1309+ # output = self.conv2d_transpose(x, filters, output_size)
1310+ #
1311+ # return output
12461312
12471313
12481314def conv2d_transpose (
0 commit comments