@@ -1267,6 +1267,7 @@ def __init__(
12671267 self .dilations = dilations
12681268 self .name = name
12691269 self .data_format , self .padding = preprocess_2d_format (data_format , padding )
1270+ self ._padding = padding
12701271
12711272 def __call__ (self , input , filters ):
12721273 if self .data_format == 'NHWC' :
@@ -1320,13 +1321,14 @@ def __call__(self, input, filters):
13201321 output_w = input_w * strides_w
13211322 else :
13221323 if isinstance (self .padding , int ):
1323- output_h = input_h * strides_h + max (kernel_h - strides_h , 0 ) - 2 * self .padding
1324- output_w = input_w * strides_w + max (kernel_w - strides_w , 0 ) - 2 * self .padding
1325- self .padding = [[0 , 0 ], [self .padding , self .padding ],[self .padding , self .padding ], [0 , 0 ]]
1324+ output_h = input_h * strides_h + max (kernel_h - strides_h , 0 ) - 2 * self ._padding
1325+ output_w = input_w * strides_w + max (kernel_w - strides_w , 0 ) - 2 * self ._padding
1326+ self .padding = [[0 , 0 ], [self ._padding , self ._padding ],[self ._padding , self ._padding ], [0 , 0 ]]
13261327 else :
1327- output_h = input_h * strides_h + max (kernel_h - strides_h , 0 ) - 2 * self .padding [0 ]
1328- output_w = input_w * strides_w + max (kernel_w - strides_w , 0 ) - 2 * self .padding [1 ]
1329- self .padding = [[0 , 0 ], [self .padding [0 ], self .padding [0 ]],[self .padding [1 ], self .padding [1 ]], [0 , 0 ]]
1328+ print (input_h , strides_h , kernel_h , strides_h , self ._padding [0 ], self ._padding )
1329+ output_h = input_h * strides_h + max (kernel_h - strides_h , 0 ) - 2 * self ._padding [0 ]
1330+ output_w = input_w * strides_w + max (kernel_w - strides_w , 0 ) - 2 * self ._padding [1 ]
1331+ self .padding = [[0 , 0 ], [self ._padding [0 ], self ._padding [0 ]],[self ._padding [1 ], self ._padding [1 ]], [0 , 0 ]]
13301332
13311333 if self .data_format == 'NCHW' :
13321334 out_shape = (batch_size , output_channels , output_h , output_w )
0 commit comments