@@ -34,11 +34,62 @@ def padding_format(padding):
3434 padding = "VALID"
3535 elif padding == None :
3636 padding = None
37+ elif isinstance (padding , tuple ) or isinstance (padding , int ):
38+ return padding
3739 else :
3840 raise Exception ("Unsupported padding: " + str (padding ))
3941 return padding
4042
4143
44+ def preprocess_padding (padding , dim = '2d' , data_format = 'NHWC' ):
45+ # When explicit padding is used and data_format is "NHWC",
46+ # this should be in the form [[0, 0], [pad_top, pad_bottom],[pad_left, pad_right], [0, 0]].
47+ # When explicit padding used and data_format is "NCHW",
48+ # this should be in the form [[0, 0], [0, 0],[pad_top, pad_bottom], [pad_left, pad_right]].
49+ check_padding (padding , dim )
50+ if dim == '1d' :
51+ if data_format == 'NWC' :
52+ out_padding = [[0 , 0 ], [padding , padding ], [0 , 0 ]]
53+ else :
54+ out_padding = [[0 , 0 ], [0 , 0 ], [padding , padding ]]
55+ elif dim == '2d' :
56+ if isinstance (padding , int ):
57+ if data_format == 'NHWC' :
58+ out_padding = [[0 , 0 ], [padding , padding ], [padding , padding ], [0 , 0 ]]
59+ else :
60+ out_padding = [[0 , 0 ], [0 , 0 ],[padding , padding ], [padding , padding ]]
61+ elif isinstance (padding , tuple ):
62+ if data_format == 'NHWC' :
63+ out_padding = [[0 , 0 ], [padding [0 ], padding [0 ]], [padding [1 ], padding [1 ]], [0 , 0 ]]
64+ else :
65+ out_padding = [[0 , 0 ], [0 , 0 ],[padding [0 ], padding [0 ]], [padding [1 ], padding [1 ]]]
66+ elif dim == '3d' :
67+ if isinstance (padding , int ):
68+ if data_format == 'NDHWC' :
69+ out_padding = [[0 , 0 ], [padding , padding ], [padding , padding ], [padding , padding ], [0 , 0 ]]
70+ else :
71+ out_padding = [[0 , 0 ], [0 , 0 ], [padding , padding ], [padding , padding ], [padding , padding ]]
72+ elif isinstance (padding , tuple ):
73+ if data_format == 'NDHWC' :
74+ out_padding = [[0 , 0 ], [padding [0 ], padding [0 ]], [padding [1 ], padding [1 ]], [padding [2 ], padding [2 ]], [0 , 0 ]]
75+ else :
76+ out_padding = [[0 , 0 ], [0 , 0 ], [padding [0 ], padding [0 ]], [padding [1 ], padding [1 ]], [padding [2 ], padding [2 ]]]
77+ else :
78+ raise RuntimeError ("Unsupported input dimensions." )
79+ return out_padding
80+
81+
82+ def check_padding (padding , dim = '2d' ):
83+ if dim == '1d' and isinstance (object , tuple ):
84+ raise RuntimeError ("expected padding to be a single integer value or a list of 1 values to match the convolution dimensions." )
85+ if dim == '2d' and isinstance (object , tuple ) and len (padding ) > 2 :
86+ raise RuntimeError ("expected padding to be a single integer value or a list of 2 values to match the convolution dimensions." )
87+ if dim == '3d' and isinstance (object , tuple ) and len (padding ) > 3 :
88+ raise RuntimeError ("expected padding to be a single integer value or a list of 3 values to match the convolution dimensions." )
89+
90+
91+
92+
4293def preprocess_1d_format (data_format , padding ):
4394 """
4495 Checks that the 1-D dataformat format correspond format.
@@ -438,8 +489,16 @@ def __init__(self, stride, padding, data_format='NWC', dilations=None, out_chann
438489 self .stride = stride
439490 self .dilations = dilations
440491 self .data_format , self .padding = preprocess_1d_format (data_format , padding )
492+ self .pad_value = None
493+
494+ if isinstance (padding , int ):
495+ self .pad_value = preprocess_padding (self .padding , '1d' , self .data_format )
496+ self .padding = 'VALID'
497+
441498
442499 def __call__ (self , input , filters ):
500+ if self .pad_value is not None :
501+ input = tf .pad (input , paddings = self .pad_value )
443502 outputs = tf .nn .conv1d (
444503 input = input ,
445504 filters = filters ,
@@ -501,6 +560,9 @@ def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_cha
501560 self .dilations = dilations
502561 self .data_format , self .padding = preprocess_2d_format (data_format , padding )
503562
563+ if isinstance (padding , int ) or isinstance (padding , tuple ):
564+ self .padding = preprocess_padding (self .padding , '2d' , self .data_format )
565+
504566 def __call__ (self , input , filters ):
505567 outputs = tf .nn .conv2d (
506568 input = input ,
@@ -559,8 +621,15 @@ def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_ch
559621 self .strides = strides
560622 self .dilations = dilations
561623 self .data_format , self .padding = preprocess_3d_format (data_format , padding )
624+ self .pad_value = None
625+
626+ if isinstance (padding , int ) or isinstance (padding , tuple ):
627+ self .pad_value = preprocess_padding (self .padding , '3d' , self .data_format )
628+ self .padding = 'VALID'
562629
563630 def __call__ (self , input , filters ):
631+ if self .pad_value is not None :
632+ input = tf .pad (input , paddings = self .pad_value )
564633 outputs = tf .nn .conv3d (
565634 input = input ,
566635 filters = filters ,
@@ -676,8 +745,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
676745 self .data_format , self .padding = preprocess_1d_format (data_format = data_format , padding = padding )
677746 self .ksize = ksize
678747 self .strides = strides
748+ self .padding_value = None
749+ if not isinstance (self .padding , str ):
750+ self .padding_value = preprocess_padding (self .padding , '1d' , self .data_format )
751+ self .padding = "VALID"
679752
680753 def __call__ (self , inputs ):
754+ if self .padding_value is not None :
755+ inputs = tf .pad (inputs , self .padding_value )
681756 outputs = tf .nn .max_pool (
682757 input = inputs , ksize = self .ksize , strides = self .strides , padding = self .padding , data_format = self .data_format
683758 )
@@ -695,10 +770,22 @@ def __init__(self, ksize, strides, padding, data_format=None):
695770 def __call__ (self , inputs ):
696771 if len (inputs .shape ) == 3 :
697772 self .data_format , self .padding = preprocess_1d_format (data_format = self .data_format , padding = self .padding )
773+ if not isinstance (self .padding , str ):
774+ self .padding_value = preprocess_padding (self .padding , '1d' , self .data_format )
775+ self .padding = "VALID"
776+ inputs = tf .pad (inputs , self .padding_value )
698777 elif len (inputs .shape ) == 4 :
699778 self .data_format , self .padding = preprocess_2d_format (data_format = self .data_format , padding = self .padding )
779+ if not isinstance (self .padding , str ):
780+ self .padding_value = preprocess_padding (self .padding , '2d' , self .data_format )
781+ self .padding = "VALID"
782+ inputs = tf .pad (inputs , self .padding_value )
700783 elif len (inputs .shape ) == 5 :
701784 self .data_format , self .padding = preprocess_3d_format (data_format = self .data_format , padding = self .padding )
785+ if not isinstance (self .padding , str ):
786+ self .padding_value = preprocess_padding (self .padding , '3d' , self .data_format )
787+ self .padding = "VALID"
788+ inputs = tf .pad (inputs , self .padding_value )
702789
703790 outputs = tf .nn .max_pool (
704791 input = inputs , ksize = self .ksize , strides = self .strides , padding = self .padding , data_format = self .data_format
@@ -749,8 +836,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
749836 self .data_format , self .padding = preprocess_1d_format (data_format = data_format , padding = padding )
750837 self .ksize = ksize
751838 self .strides = strides
839+ self .padding_value = None
840+ if not isinstance (self .padding , str ):
841+ self .padding_value = preprocess_padding (self .padding , '1d' , self .data_format )
842+ self .padding = "VALID"
752843
753844 def __call__ (self , inputs ):
845+ if self .padding_value is not None :
846+ inputs = tf .pad (inputs , self .padding_value )
754847 outputs = tf .nn .pool (
755848 input = inputs ,
756849 window_shape = self .ksize ,
@@ -769,8 +862,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
769862 self .strides = strides
770863 self .data_format = data_format
771864 self .padding = padding_format (padding )
865+ self .padding_value = None
866+ if not isinstance (self .padding , str ):
867+ self .padding_value = preprocess_padding (self .padding , '2d' , self .data_format )
868+ self .padding = "VALID"
772869
773870 def __call__ (self , inputs ):
871+ if self .padding_value is not None :
872+ inputs = tf .pad (inputs , self .padding_value )
774873 outputs = tf .nn .avg_pool (
775874 input = inputs , ksize = self .ksize , strides = self .strides , padding = self .padding , data_format = self .data_format
776875 )
@@ -819,8 +918,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
819918 self .data_format , self .padding = preprocess_3d_format (data_format , padding )
820919 self .ksize = ksize
821920 self .strides = strides
921+ self .padding_value = None
922+ if not isinstance (self .padding , str ):
923+ self .padding_value = preprocess_padding (self .padding , '3d' , self .data_format )
924+ self .padding = "VALID"
822925
823926 def __call__ (self , inputs ):
927+ if self .padding_value is not None :
928+ inputs = tf .pad (inputs , self .padding_value )
824929 outputs = tf .nn .max_pool3d (
825930 input = inputs ,
826931 ksize = self .ksize ,
@@ -876,8 +981,14 @@ def __init__(self, ksize, strides, padding, data_format=None):
876981 self .data_format , self .padding = preprocess_3d_format (data_format , padding )
877982 self .ksize = ksize
878983 self .strides = strides
984+ self .padding_value = None
985+ if not isinstance (self .padding , str ):
986+ self .padding_value = preprocess_padding (self .padding , '3d' , self .data_format )
987+ self .padding = "VALID"
879988
880989 def __call__ (self , inputs ):
990+ if self .padding_value is not None :
991+ inputs = tf .pad (inputs , self .padding_value )
881992 outputs = tf .nn .avg_pool3d (
882993 input = inputs ,
883994 ksize = self .ksize ,
0 commit comments