@@ -159,12 +159,6 @@ def __repr__(self):
159159
160160 def build (self , inputs_shape = None ):
161161 # https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/nn/pool
162- if self .data_format == 'channels_last' :
163- self .data_format = 'NWC'
164- elif self .data_format == 'channels_first' :
165- self .data_format = 'NCW'
166- else :
167- raise Exception ("unsupported data format" )
168162 self ._filter_size = [self .kernel_size ]
169163 self ._stride = [self .stride ]
170164 self .max_pool = tlx .ops .MaxPool1d (
@@ -238,12 +232,6 @@ def __repr__(self):
238232
239233 def build (self , inputs_shape = None ):
240234 # https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/nn/pool
241- if self .data_format == 'channels_last' :
242- self .data_format = 'NWC'
243- elif self .data_format == 'channels_first' :
244- self .data_format = 'NCW'
245- else :
246- raise Exception ("unsupported data format" )
247235 self ._filter_size = [self .kernel_size ]
248236 self ._stride = [self .stride ]
249237 self .avg_pool = tlx .ops .AvgPool1d (
@@ -318,10 +306,8 @@ def __repr__(self):
318306
319307 def build (self , inputs_shape = None ):
320308 if self .data_format == 'channels_last' :
321- self .data_format = 'NHWC'
322309 self ._stride = [1 , self .stride [0 ], self .stride [1 ], 1 ]
323310 elif self .data_format == 'channels_first' :
324- self .data_format = 'NCHW'
325311 self ._stride = [1 , 1 , self .stride [0 ], self .stride [1 ]]
326312 else :
327313 raise Exception ("unsupported data format" )
@@ -398,10 +384,8 @@ def __repr__(self):
398384
399385 def build (self , inputs_shape = None ):
400386 if self .data_format == 'channels_last' :
401- self .data_format = 'NHWC'
402387 self ._stride = [1 , self .stride [0 ], self .stride [1 ], 1 ]
403388 elif self .data_format == 'channels_first' :
404- self .data_format = 'NCHW'
405389 self ._stride = [1 , 1 , self .stride [0 ], self .stride [1 ]]
406390 else :
407391 raise Exception ("unsupported data format" )
@@ -480,10 +464,8 @@ def __repr__(self):
480464
481465 def build (self , inputs_shape = None ):
482466 if self .data_format == 'channels_last' :
483- self .data_format = 'NDHWC'
484467 self ._stride = [1 , self .stride [0 ], self .stride [1 ], self .stride [2 ], 1 ]
485468 elif self .data_format == 'channels_first' :
486- self .data_format = 'NCDHW'
487469 self ._stride = [1 , 1 , self .stride [0 ], self .stride [1 ], self .stride [2 ]]
488470 else :
489471 raise Exception ("unsupported data format" )
@@ -562,12 +544,6 @@ def __repr__(self):
562544
563545 def build (self , inputs_shape = None ):
564546 self ._stride = [1 , self .stride [0 ], self .stride [1 ], self .stride [2 ], 1 ]
565- if self .data_format == 'channels_last' :
566- self .data_format = 'NDHWC'
567- elif self .data_format == 'channels_first' :
568- self .data_format = 'NCDHW'
569- else :
570- raise Exception ("unsupported data format" )
571547 self .avg_pool3d = tlx .ops .AvgPool3d (
572548 ksize = self .kernel_size , strides = self ._stride , padding = self .padding , data_format = self .data_format
573549 )
@@ -1055,12 +1031,6 @@ def __repr__(self):
10551031 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
10561032
10571033 def build (self , inputs_shape = None ):
1058- if self .data_format == 'channels_last' :
1059- self .data_format = 'NWC'
1060- elif self .data_format == 'channels_first' :
1061- self .data_format = 'NCW'
1062- else :
1063- raise Exception ("unsupported data format" )
10641034
10651035 self .adaptivemeanpool1d = tlx .ops .AdaptiveMeanPool1D (output_size = self .output_size , data_format = self .data_format )
10661036
@@ -1113,12 +1083,6 @@ def __repr__(self):
11131083 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
11141084
11151085 def build (self , inputs_shape = None ):
1116- if self .data_format == 'channels_last' :
1117- self .data_format = 'NHWC'
1118- elif self .data_format == 'channels_first' :
1119- self .data_format = 'NCHW'
1120- else :
1121- raise Exception ("unsupported data format" )
11221086
11231087 if isinstance (self .output_size , int ):
11241088 self .output_size = (self .output_size , ) * 2
@@ -1174,12 +1138,6 @@ def __repr__(self):
11741138 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
11751139
11761140 def build (self , inputs_shape = None ):
1177- if self .data_format == 'channels_last' :
1178- self .data_format = 'NDHWC'
1179- elif self .data_format == 'channels_first' :
1180- self .data_format = 'NCDHW'
1181- else :
1182- raise Exception ("unsupported data format" )
11831141
11841142 if isinstance (self .output_size , int ):
11851143 self .output_size = (self .output_size , ) * 3
@@ -1235,12 +1193,6 @@ def __repr__(self):
12351193 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
12361194
12371195 def build (self , inputs_shape = None ):
1238- if self .data_format == 'channels_last' :
1239- self .data_format = 'NWC'
1240- elif self .data_format == 'channels_first' :
1241- self .data_format = 'NCW'
1242- else :
1243- raise Exception ("unsupported data format" )
12441196
12451197 self .adaptivemaxpool1d = tlx .ops .AdaptiveMaxPool1D (output_size = self .output_size , data_format = self .data_format )
12461198
@@ -1293,12 +1245,7 @@ def __repr__(self):
12931245 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
12941246
12951247 def build (self , inputs_shape = None ):
1296- if self .data_format == 'channels_last' :
1297- self .data_format = 'NHWC'
1298- elif self .data_format == 'channels_first' :
1299- self .data_format = 'NCHW'
1300- else :
1301- raise Exception ("unsupported data format" )
1248+
13021249 if isinstance (self .output_size , int ):
13031250 self .output_size = (self .output_size , ) * 2
13041251
@@ -1353,12 +1300,6 @@ def __repr__(self):
13531300 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
13541301
13551302 def build (self , inputs_shape = None ):
1356- if self .data_format == 'channels_last' :
1357- self .data_format = 'NDHWC'
1358- elif self .data_format == 'channels_first' :
1359- self .data_format = 'NCDHW'
1360- else :
1361- raise Exception ("unsupported data format" )
13621303
13631304 if isinstance (self .output_size , int ):
13641305 self .output_size = (self .output_size , ) * 3
0 commit comments