Skip to content

Commit 946d19f

Browse files
committed
Added support for Conv and Pooling padding given as an int, a tuple, or a str.
1 parent 4453a4d commit 946d19f

File tree

7 files changed

+189
-30
lines changed

7 files changed

+189
-30
lines changed

tensorlayerx/backend/ops/mindspore_nn.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,41 @@ def padding_format(padding):
4444
padding = "valid"
4545
elif padding == None:
4646
padding = None
47+
elif isinstance(padding, tuple) or isinstance(padding, int):
48+
return padding
4749
else:
4850
raise Exception("Unsupported padding: " + str(padding))
4951
return padding
5052

5153

54+
def preprocess_padding(padding, dim='2d'):
    """
    Convert user-supplied explicit padding into the value handed to the conv primitive as ``pad``.

    Parameters
    ----------
    padding : int or tuple
        Explicit padding. For '1d' an int or a 1-element tuple; for '2d' / '3d'
        an int or a tuple with one entry per spatial dimension.
    dim : str
        One of '1d', '2d', '3d'.

    Returns
    -------
    int or tuple
        '1d': ``(0, 0, p, p)``.
        '2d' / '3d': every tuple entry duplicated, e.g. ``(a, b)`` -> ``(a, a, b, b)``;
        a non-tuple value is returned unchanged.

    Raises
    ------
    RuntimeError
        If ``dim`` is not supported, or a '1d' tuple does not hold exactly one value.
        ``check_padding`` may also raise for tuples of the wrong length.
    """
    check_padding(padding, dim)

    if dim == '1d':
        if isinstance(padding, tuple):
            # Previously a tuple was embedded as-is, producing a nested tuple
            # such as (0, 0, (p,), (p,)); unpack the single value instead.
            if len(padding) != 1:
                raise RuntimeError(
                    "expected padding to be a single integer value or a list of 1 values to match the convolution dimensions."
                )
            padding = padding[0]
        # Conv1D in this file runs on the 2-D primitive with kernel (1, k_size),
        # so the first pair of pads stays zero.
        return (0, 0, padding, padding)

    if dim == '2d':
        spatial_rank = 2
    elif dim == '3d':
        spatial_rank = 3
    else:
        raise RuntimeError("Unsupported input dimensions.")

    if isinstance(padding, tuple):
        # Duplicate each per-axis amount so both sides of the axis get the same padding.
        return tuple(padding[axis] for axis in range(spatial_rank) for _ in range(2))
    return padding
71+
72+
73+
def check_padding(padding, dim='2d'):
    """
    Validate that a tuple ``padding`` has exactly one entry per spatial dimension.

    Only tuples are length-checked; any other value (for example a single int)
    passes through without error, as does an unrecognised ``dim``.

    Parameters
    ----------
    padding : int or tuple
        Explicit padding to validate.
    dim : str
        One of '1d', '2d', '3d'.

    Raises
    ------
    RuntimeError
        If ``padding`` is a tuple whose length differs from the number of
        spatial dimensions implied by ``dim``.
    """
    expected = {'1d': 1, '2d': 2, '3d': 3}.get(dim)
    if expected is None or not isinstance(padding, tuple):
        return
    # The '1d' check used to test ``isinstance(object, tuple)``, which is always
    # False, so it never fired. Too-short tuples are rejected here as well,
    # because they would otherwise fail later with an IndexError.
    if len(padding) != expected:
        raise RuntimeError(
            "expected padding to be a single integer value or a list of " + str(expected) +
            " values to match the convolution dimensions."
        )
80+
81+
5282
def preprocess_1d_format(data_format, padding):
5383
"""
5484
Checks that the 1-D dataformat format correspond format.
@@ -458,20 +488,24 @@ def bias_add(x, bias):
458488
"""
459489
raise NotImplementedError
460490

461-
462491
class Conv1D(Cell):
463492

464493
def __init__(self, stride, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None):
465494
super(Conv1D, self).__init__()
466-
self.data_format, self.padding = preprocess_1d_format(data_format, padding)
495+
self.data_format, self.pad_mode = preprocess_1d_format(data_format, padding)
496+
self.padding = 0
467497
self.stride = (1, stride)
468498
self.dilations = (1, dilations)
469499
self.k_size = (1, k_size)
470500
self.out_channel = out_channel
471501

502+
if isinstance(self.pad_mode, int):
503+
self.padding = preprocess_padding(self.pad_mode, '1d')
504+
self.pad_mode = "pad"
505+
472506
self.conv2d = P.Conv2D(
473-
out_channel=self.out_channel, kernel_size=self.k_size, pad_mode=self.padding, stride=self.stride,
474-
dilation=self.dilations, mode=1, group=1
507+
out_channel=self.out_channel, kernel_size=self.k_size, pad_mode=self.pad_mode, pad=self.padding,
508+
stride=self.stride, dilation=self.dilations, mode=1, group=1
475509
)
476510

477511
self.expand_dims = P.ExpandDims()
@@ -528,16 +562,21 @@ class Conv2D(Cell):
528562

529563
def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None):
530564
super(Conv2D, self).__init__()
531-
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
565+
self.data_format, self.pad_mode = preprocess_2d_format(data_format, padding)
566+
self.padding = 0
532567
if self.data_format is 'NHWC':
533568
self._stride = (strides[1], strides[2])
534569
self._dilation = (dilations[1], dilations[2])
535570
elif self.data_format is 'NCHW':
536571
self._stride = (strides[2], strides[3])
537572
self._dilation = (dilations[2], dilations[3])
538573

574+
if isinstance(self.pad_mode, int) or isinstance(self.pad_mode, tuple):
575+
self.padding = preprocess_padding(self.pad_mode, '2d')
576+
self.pad_mode = "pad"
577+
539578
self.conv2d = P.Conv2D(
540-
out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self._stride,
579+
out_channel=out_channel, kernel_size=k_size, pad_mode=self.pad_mode, pad=self.padding, stride=self._stride,
541580
dilation=self._dilation, mode=1, group=1, data_format=self.data_format
542581
)
543582

@@ -578,8 +617,8 @@ class Conv3D(Cell):
578617

579618
def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
580619
super(Conv3D, self).__init__()
581-
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
582-
620+
self.data_format, self.pad_mode = preprocess_3d_format(data_format, padding)
621+
self.padding = 0
583622
if self.data_format is 'NDHWC':
584623
self.ms_stride = (strides[1], strides[2], strides[3])
585624
self.ms_dilation = (dilations[1], dilations[2], dilations[3])
@@ -588,9 +627,13 @@ def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_ch
588627
self.ms_stride = (strides[2], strides[3], strides[4])
589628
self.ms_dilation = (dilations[2], dilations[3], dilations[4])
590629

630+
if isinstance(self.pad_mode, int) or isinstance(self.pad_mode, tuple):
631+
self.padding = preprocess_padding(self.pad_mode, '3d')
632+
self.pad_mode = "pad"
633+
591634
self.conv3d = P.Conv3D(
592-
out_channel=out_channel, kernel_size=k_size, pad_mode=self.padding, stride=self.ms_stride,
593-
dilation=self.ms_dilation, data_format=data_format
635+
out_channel=out_channel, kernel_size=k_size, pad_mode=self.pad_mode, pad=self.padding, stride=self.ms_stride,
636+
dilation=self.ms_dilation, data_format=self.data_format
594637
)
595638

596639
def construct(self, input, filters):

tensorlayerx/backend/ops/paddle_nn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def padding_format(padding):
3535
padding = "VALID"
3636
elif padding == None:
3737
padding = None
38+
elif isinstance(padding, tuple) or isinstance(padding, int):
39+
return padding
3840
else:
3941
raise Exception("Unsupported padding: " + str(padding))
4042
return padding

tensorlayerx/backend/ops/tensorflow_nn.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,62 @@ def padding_format(padding):
3434
padding = "VALID"
3535
elif padding == None:
3636
padding = None
37+
elif isinstance(padding, tuple) or isinstance(padding, int):
38+
return padding
3739
else:
3840
raise Exception("Unsupported padding: " + str(padding))
3941
return padding
4042

4143

44+
def preprocess_padding(padding, dim='2d', data_format='NHWC'):
    """
    Build a per-axis ``[before, after]`` padding list from an int or tuple padding.

    The batch and channel axes always get ``[0, 0]``. Each spatial axis gets the
    same amount on both sides.
    When explicit padding is used and data_format is "NHWC", the result has the form
    [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]].
    When explicit padding is used and data_format is "NCHW", the result has the form
    [[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]].

    Parameters
    ----------
    padding : int or tuple
        A single amount applied to every spatial axis, or a tuple holding one
        amount per spatial axis.
    dim : str
        One of '1d', '2d', '3d'.
    data_format : str
        'NWC' / 'NHWC' / 'NDHWC' selects the channels-last layout for the matching
        ``dim``; any other value is treated as channels-first.

    Returns
    -------
    list of list
        One ``[before, after]`` pair per input axis.

    Raises
    ------
    RuntimeError
        If ``dim`` is not supported, if a tuple has the wrong number of entries,
        or if ``padding`` is neither an int nor a tuple.
    """
    check_padding(padding, dim)

    layouts = {'1d': (1, 'NWC'), '2d': (2, 'NHWC'), '3d': (3, 'NDHWC')}
    if dim not in layouts:
        raise RuntimeError("Unsupported input dimensions.")
    spatial_rank, channels_last_format = layouts[dim]

    if isinstance(padding, tuple):
        # Reject wrong-length tuples here too: extra entries used to be silently
        # ignored and a tuple in the '1d' case was nested into the result.
        if len(padding) != spatial_rank:
            raise RuntimeError(
                "expected padding to be a single integer value or a list of %d values "
                "to match the convolution dimensions." % spatial_rank
            )
        amounts = padding
    elif isinstance(padding, int):
        amounts = (padding, ) * spatial_rank
    else:
        # Previously this fell through the '2d'/'3d' branches and failed with an
        # UnboundLocalError on the return statement.
        raise RuntimeError("Unsupported padding: " + str(padding))

    spatial = [[amount, amount] for amount in amounts]
    if data_format == channels_last_format:
        return [[0, 0]] + spatial + [[0, 0]]
    return [[0, 0], [0, 0]] + spatial
80+
81+
82+
def check_padding(padding, dim='2d'):
    """
    Validate that a tuple ``padding`` has exactly one entry per spatial dimension.

    Only tuples are length-checked; any other value (for example a single int)
    passes through without error, as does an unrecognised ``dim``.

    Parameters
    ----------
    padding : int or tuple
        Explicit padding to validate.
    dim : str
        One of '1d', '2d', '3d'.

    Raises
    ------
    RuntimeError
        If ``padding`` is a tuple whose length differs from the number of
        spatial dimensions implied by ``dim``.
    """
    expected = {'1d': 1, '2d': 2, '3d': 3}.get(dim)
    if expected is None or not isinstance(padding, tuple):
        return
    # Every branch used to test ``isinstance(object, tuple)``, which is always
    # False, so no padding was ever rejected. Too-short tuples are rejected here
    # as well, because they would otherwise fail later with an IndexError.
    if len(padding) != expected:
        raise RuntimeError(
            "expected padding to be a single integer value or a list of " + str(expected) +
            " values to match the convolution dimensions."
        )
89+
90+
91+
92+
4293
def preprocess_1d_format(data_format, padding):
4394
"""
4495
Checks that the 1-D dataformat format correspond format.
@@ -438,8 +489,16 @@ def __init__(self, stride, padding, data_format='NWC', dilations=None, out_chann
438489
self.stride = stride
439490
self.dilations = dilations
440491
self.data_format, self.padding = preprocess_1d_format(data_format, padding)
492+
self.pad_value = None
493+
494+
if isinstance(padding, int):
495+
self.pad_value = preprocess_padding(self.padding, '1d', self.data_format)
496+
self.padding = 'VALID'
497+
441498

442499
def __call__(self, input, filters):
500+
if self.pad_value is not None:
501+
input = tf.pad(input, paddings=self.pad_value)
443502
outputs = tf.nn.conv1d(
444503
input=input,
445504
filters=filters,
@@ -501,6 +560,9 @@ def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_cha
501560
self.dilations = dilations
502561
self.data_format, self.padding = preprocess_2d_format(data_format, padding)
503562

563+
if isinstance(padding, int) or isinstance(padding, tuple):
564+
self.padding = preprocess_padding(self.padding, '2d', self.data_format)
565+
504566
def __call__(self, input, filters):
505567
outputs = tf.nn.conv2d(
506568
input=input,
@@ -559,8 +621,15 @@ def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_ch
559621
self.strides = strides
560622
self.dilations = dilations
561623
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
624+
self.pad_value = None
625+
626+
if isinstance(padding, int) or isinstance(padding, tuple):
627+
self.pad_value = preprocess_padding(self.padding, '3d', self.data_format)
628+
self.padding = 'VALID'
562629

563630
def __call__(self, input, filters):
631+
if self.pad_value is not None:
632+
input = tf.pad(input, paddings=self.pad_value)
564633
outputs = tf.nn.conv3d(
565634
input=input,
566635
filters=filters,
@@ -676,8 +745,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
676745
self.data_format, self.padding = preprocess_1d_format(data_format=data_format, padding=padding)
677746
self.ksize = ksize
678747
self.strides = strides
748+
self.padding_value = None
749+
if not isinstance(self.padding, str):
750+
self.padding_value = preprocess_padding(self.padding, '1d', self.data_format)
751+
self.padding = "VALID"
679752

680753
def __call__(self, inputs):
754+
if self.padding_value is not None:
755+
inputs = tf.pad(inputs, self.padding_value)
681756
outputs = tf.nn.max_pool(
682757
input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
683758
)
@@ -695,10 +770,22 @@ def __init__(self, ksize, strides, padding, data_format=None):
695770
def __call__(self, inputs):
696771
if len(inputs.shape) == 3:
697772
self.data_format, self.padding = preprocess_1d_format(data_format=self.data_format, padding=self.padding)
773+
if not isinstance(self.padding, str):
774+
self.padding_value = preprocess_padding(self.padding, '1d', self.data_format)
775+
self.padding = "VALID"
776+
inputs = tf.pad(inputs, self.padding_value)
698777
elif len(inputs.shape) == 4:
699778
self.data_format, self.padding = preprocess_2d_format(data_format=self.data_format, padding=self.padding)
779+
if not isinstance(self.padding, str):
780+
self.padding_value = preprocess_padding(self.padding, '2d', self.data_format)
781+
self.padding = "VALID"
782+
inputs = tf.pad(inputs, self.padding_value)
700783
elif len(inputs.shape) == 5:
701784
self.data_format, self.padding = preprocess_3d_format(data_format=self.data_format, padding=self.padding)
785+
if not isinstance(self.padding, str):
786+
self.padding_value = preprocess_padding(self.padding, '3d', self.data_format)
787+
self.padding = "VALID"
788+
inputs = tf.pad(inputs, self.padding_value)
702789

703790
outputs = tf.nn.max_pool(
704791
input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
@@ -749,8 +836,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
749836
self.data_format, self.padding = preprocess_1d_format(data_format=data_format, padding=padding)
750837
self.ksize = ksize
751838
self.strides = strides
839+
self.padding_value = None
840+
if not isinstance(self.padding, str):
841+
self.padding_value = preprocess_padding(self.padding, '1d', self.data_format)
842+
self.padding = "VALID"
752843

753844
def __call__(self, inputs):
845+
if self.padding_value is not None:
846+
inputs = tf.pad(inputs, self.padding_value)
754847
outputs = tf.nn.pool(
755848
input=inputs,
756849
window_shape=self.ksize,
@@ -769,8 +862,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
769862
self.strides = strides
770863
self.data_format = data_format
771864
self.padding = padding_format(padding)
865+
self.padding_value = None
866+
if not isinstance(self.padding, str):
867+
self.padding_value = preprocess_padding(self.padding, '2d', self.data_format)
868+
self.padding = "VALID"
772869

773870
def __call__(self, inputs):
871+
if self.padding_value is not None:
872+
inputs = tf.pad(inputs, self.padding_value)
774873
outputs = tf.nn.avg_pool(
775874
input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
776875
)
@@ -819,8 +918,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
819918
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
820919
self.ksize = ksize
821920
self.strides = strides
921+
self.padding_value = None
922+
if not isinstance(self.padding, str):
923+
self.padding_value = preprocess_padding(self.padding, '3d', self.data_format)
924+
self.padding = "VALID"
822925

823926
def __call__(self, inputs):
927+
if self.padding_value is not None:
928+
inputs = tf.pad(inputs, self.padding_value)
824929
outputs = tf.nn.max_pool3d(
825930
input=inputs,
826931
ksize=self.ksize,
@@ -876,8 +981,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
876981
self.data_format, self.padding = preprocess_3d_format(data_format, padding)
877982
self.ksize = ksize
878983
self.strides = strides
984+
self.padding_value = None
985+
if not isinstance(self.padding, str):
986+
self.padding_value = preprocess_padding(self.padding, '3d', self.data_format)
987+
self.padding = "VALID"
879988

880989
def __call__(self, inputs):
990+
if self.padding_value is not None:
991+
inputs = tf.pad(inputs, self.padding_value)
881992
outputs = tf.nn.avg_pool3d(
882993
input=inputs,
883994
ksize=self.ksize,

tensorlayerx/backend/ops/torch_nn.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def padding_format(padding):
2727
padding = "valid"
2828
elif padding == None:
2929
padding = None
30+
elif isinstance(padding, tuple) or isinstance(padding, int):
31+
return padding
3032
else:
3133
raise Exception("Unsupported padding: " + str(padding))
3234
return padding
@@ -798,17 +800,17 @@ def __call__(self, inputs):
798800
if self.padding in ['SAME', 'same']:
799801
out = self.maxpool1d_same_padding(inputs)
800802
else:
801-
out = F.max_pool1d(inputs, self.ksize, self.strides)
803+
out = F.max_pool1d(inputs, self.ksize, self.strides, padding=self.padding)
802804
if len(inputs.shape) == 4:
803805
if self.padding in ['SAME', 'same']:
804806
out = self.maxpool2d_same_padding(inputs)
805807
else:
806-
out = F.max_pool2d(inputs, self.ksize, (self.strides[1], self.strides[2]))
808+
out = F.max_pool2d(inputs, self.ksize, (self.strides[1], self.strides[2]), padding=self.padding)
807809
if len(inputs.shape) == 5:
808810
if self.padding in ['SAME', 'same']:
809811
out = self.maxpool3d_same_padding(inputs)
810812
else:
811-
out = F.max_pool3d(inputs, self.ksize, (self.strides[1], self.strides[2], self.strides[3]))
813+
out = F.max_pool3d(inputs, self.ksize, (self.strides[1], self.strides[2], self.strides[3]), padding=self.padding)
812814

813815
if self.data_format == 'channels_last':
814816
return nchw_to_nhwc(out)
@@ -900,17 +902,17 @@ def __call__(self, inputs):
900902
if self.padding in ['SAME', 'same']:
901903
out = self.avgpool1d_same_padding(inputs)
902904
else:
903-
out = F.avg_pool1d(inputs, self.ksize, self.strides)
905+
out = F.avg_pool1d(inputs, self.ksize, self.strides, padding=self.padding)
904906
if len(inputs.shape) == 4:
905907
if self.padding in ['SAME', 'same']:
906908
out = self.avgpool2d_same_padding(inputs)
907909
else:
908-
out = F.avg_pool2d(inputs, self.ksize, (self.strides[1], self.strides[2]))
910+
out = F.avg_pool2d(inputs, self.ksize, (self.strides[1], self.strides[2]), padding=self.padding)
909911
if len(inputs.shape) == 5:
910912
if self.padding in ['SAME', 'same']:
911913
out = self.avgpool3d_same_padding(inputs)
912914
else:
913-
out = F.avg_pool3d(inputs, self.ksize, (self.strides[1], self.strides[2], self.strides[3]))
915+
out = F.avg_pool3d(inputs, self.ksize, (self.strides[1], self.strides[2], self.strides[3]), padding=self.padding)
914916

915917
if self.data_format == 'channels_last':
916918
return nchw_to_nhwc(out)

tensorlayerx/nn/core/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def construct_graph(inputs, outputs):
526526
def select_attrs(obj):
527527
attrs_dict = obj.__dict__
528528
attrs = {}
529-
_select_key = ['kernel_size', 'stride', 'act', 'padding', 'data_format', 'concat_dim']
529+
_select_key = ['kernel_size', 'stride', 'act', 'padding', 'data_format', 'concat_dim', 'dilation', 'bias']
530530
for k in _select_key:
531531
if k in attrs_dict:
532532
if k == 'act':

0 commit comments

Comments
 (0)