Skip to content

Commit 57eae1b

Browse files
committed
fix tensorflow_nn some padding bugs
1 parent afd1a1b commit 57eae1b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tensorlayerx/backend/ops/tensorflow_nn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,8 @@ class DepthwiseConv2d(object):
11001100

11011101
def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1, in_channels=None):
11021102
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
1103+
if isinstance(padding, int) or isinstance(padding, tuple):
1104+
self.padding = preprocess_padding(self.padding, '2d', self.data_format)
11031105
self.strides = strides
11041106
self.dilations = dilations
11051107
self.depthwise_conv = GroupConv2D(strides=self.strides, padding=self.padding, data_format=self.data_format,
@@ -1749,6 +1751,8 @@ def __init__(self, strides, padding, data_format, dilations, out_channel, k_size
17491751
self.strides = strides
17501752
self.dilations = dilations
17511753
self.groups = groups
1754+
if isinstance(padding, int) or isinstance(padding, tuple):
1755+
self.padding = preprocess_padding(self.padding, '2d', self.data_format)
17521756
if self.data_format == 'NHWC':
17531757
self.channels_axis = 3
17541758
else:
@@ -1789,7 +1793,6 @@ class SeparableConv1D(object):
17891793

17901794
def __init__(self, stride, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
17911795
self.data_format, self.padding = preprocess_1d_format(data_format, padding)
1792-
17931796
if self.data_format == 'NWC':
17941797
self.spatial_start_dim = 1
17951798
self.strides = (1, stride, stride, 1)
@@ -1819,6 +1822,8 @@ class SeparableConv2D(object):
18191822

18201823
def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
18211824
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
1825+
if isinstance(padding, int) or isinstance(padding, tuple):
1826+
self.padding = preprocess_padding(self.padding, '2d', self.data_format)
18221827
self.strides = strides
18231828
self.dilations = (dilations[2], dilations[2])
18241829

0 commit comments

Comments
 (0)