@@ -40,6 +40,37 @@ def padding_format(padding):
4040 raise Exception ("Unsupported padding: " + str (padding ))
4141 return padding
4242
43+ def channel_format (data_format , dim = '2d' ):
44+ if dim == '1d' :
45+ if data_format in ["channels_last" , "NWC" , 'NLC' ]:
46+ data_format = "NWC"
47+ elif data_format in ["channels_first" , "NCW" , 'NCL' ]:
48+ data_format = "NCW"
49+ elif data_format == None :
50+ data_format = None
51+ else :
52+ raise Exception ("Unsupported data format: " + str (data_format ))
53+ elif dim == '2d' :
54+ if data_format in ["channels_last" , "NHWC" ]:
55+ data_format = "NHWC"
56+ elif data_format in ["channels_first" , "NCHW" ]:
57+ data_format = "NCHW"
58+ elif data_format == None :
59+ data_format = None
60+ else :
61+ raise Exception ("Unsupported data format: " + str (data_format ))
62+ elif dim == '3d' :
63+ if data_format in ['channels_last' , 'NDHWC' ]:
64+ data_format = 'NDHWC'
65+ elif data_format in ['channels_first' , 'NCDHW' ]:
66+ data_format = 'NCDHW'
67+ elif data_format == None :
68+ data_format = None
69+ else :
70+ raise Exception ("Unsupported data format: " + str (data_format ))
71+ else :
72+ raise Exception ("dim must be '1d', '2d', '3d'." )
73+ return data_format
4374
4475def preprocess_padding (padding , dim = '2d' , data_format = 'NHWC' ):
4576 # When explicit padding is used and data_format is "NHWC",
@@ -88,8 +119,6 @@ def check_padding(padding, dim='2d'):
88119 raise RuntimeError ("expected padding to be a single integer value or a list of 3 values to match the convolution dimensions." )
89120
90121
91-
92-
93122def preprocess_1d_format (data_format , padding ):
94123 """
95124 Checks that the 1-D dataformat format correspond format.
@@ -105,14 +134,7 @@ def preprocess_1d_format(data_format, padding):
105134 -------
106135 str "NWC" or "NCW" and "SAME" or "VALID"
107136 """
108- if data_format in ["channels_last" , "NWC" , 'NLC' ]:
109- data_format = "NWC"
110- elif data_format in ["channels_first" , "NCW" , 'NCL' ]:
111- data_format = "NCW"
112- elif data_format == None :
113- data_format = None
114- else :
115- raise Exception ("Unsupported data format: " + str (data_format ))
137+ data_format = channel_format (data_format , dim = '1d' )
116138 padding = padding_format (padding )
117139 return data_format , padding
118140
@@ -133,14 +155,7 @@ def preprocess_2d_format(data_format, padding):
133155 str "NHWC" or "NCHW" and "SAME" or "VALID"
134156 """
135157
136- if data_format in ["channels_last" , "NHWC" ]:
137- data_format = "NHWC"
138- elif data_format in ["channels_first" , "NCHW" ]:
139- data_format = "NCHW"
140- elif data_format == None :
141- data_format = None
142- else :
143- raise Exception ("Unsupported data format: " + str (data_format ))
158+ data_format = channel_format (data_format , dim = '2d' )
144159 padding = padding_format (padding )
145160 return data_format , padding
146161
@@ -161,14 +176,7 @@ def preprocess_3d_format(data_format, padding):
161176 str "NDHWC" or "NCDHW" and "SAME" or "VALID"
162177 """
163178
164- if data_format in ['channels_last' , 'NDHWC' ]:
165- data_format = 'NDHWC'
166- elif data_format in ['channels_first' , 'NCDHW' ]:
167- data_format = 'NCDHW'
168- elif data_format == None :
169- data_format = None
170- else :
171- raise Exception ("Unsupported data format: " + str (data_format ))
179+ data_format = channel_format (data_format , dim = '3d' )
172180 padding = padding_format (padding )
173181 return data_format , padding
174182
@@ -868,10 +876,11 @@ def __init__(self, ksize, strides, padding, data_format=None):
868876 self .padding = "VALID"
869877
870878 def __call__ (self , inputs ):
879+ data_format = channel_format (self .data_format , str (len (inputs .shape ) - 2 ) + 'd' )
871880 if self .padding_value is not None :
872881 inputs = tf .pad (inputs , self .padding_value )
873882 outputs = tf .nn .avg_pool (
874- input = inputs , ksize = self .ksize , strides = self .strides , padding = self .padding , data_format = self . data_format
883+ input = inputs , ksize = self .ksize , strides = self .strides , padding = self .padding , data_format = data_format
875884 )
876885 return outputs
877886
0 commit comments