@@ -324,17 +324,19 @@ def __repr__(self):
324324 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
325325
326326 def build (self , inputs_shape = None ):
327- self ._strides = [1 , self .strides [0 ], self .strides [1 ], 1 ]
328327 if self .data_format == 'channels_last' :
328+ self ._strides = [1 , self .strides [0 ], self .strides [1 ], 1 ]
329329 self .data_format = 'NHWC'
330330 elif self .data_format == 'channels_first' :
331331 self .data_format = 'NCHW'
332+ self ._strides = [1 , 1 , self .strides [0 ], self .strides [1 ]]
332333 else :
333334 raise Exception ("unsupported data format" )
334335
335336 def forward (self , inputs ):
336337 outputs = tf .nn .max_pool (
337- input = inputs , ksize = self .filter_size , strides = self ._strides , padding = self .padding , name = self .name
338+ input = inputs , ksize = self .filter_size , strides = self ._strides , padding = self .padding , name = self .name ,
339+ data_format = self .data_format
338340 )
339341 return outputs
340342
@@ -397,17 +399,19 @@ def __repr__(self):
397399 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
398400
399401 def build (self , inputs_shape = None ):
400- self ._strides = [1 , self .strides [0 ], self .strides [1 ], 1 ]
401402 if self .data_format == 'channels_last' :
402403 self .data_format = 'NHWC'
404+ self ._strides = [1 , self .strides [0 ], self .strides [1 ], 1 ]
403405 elif self .data_format == 'channels_first' :
404406 self .data_format = 'NCHW'
407+ self ._strides = [1 , 1 , self .strides [0 ], self .strides [1 ]]
405408 else :
406409 raise Exception ("unsupported data format" )
407410
408411 def forward (self , inputs ):
409412 outputs = tf .nn .avg_pool (
410- input = inputs , ksize = self .filter_size , strides = self ._strides , padding = self .padding , name = self .name
413+ input = inputs , ksize = self .filter_size , strides = self ._strides , padding = self .padding , name = self .name ,
414+ data_format = self .data_format
411415 )
412416 return outputs
413417
@@ -473,11 +477,12 @@ def __repr__(self):
473477 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
474478
475479 def build (self , inputs_shape = None ):
476- self ._strides = [1 , self .strides [0 ], self .strides [1 ], self .strides [2 ], 1 ]
477480 if self .data_format == 'channels_last' :
478481 self .data_format = 'NDHWC'
482+ self ._strides = [1 , self .strides [0 ], self .strides [1 ], self .strides [2 ], 1 ]
479483 elif self .data_format == 'channels_first' :
480484 self .data_format = 'NCDHW'
485+ self ._strides = [1 , 1 , self .strides [0 ], self .strides [1 ], self .strides [2 ]]
481486 else :
482487 raise Exception ("unsupported data format" )
483488
@@ -554,11 +559,12 @@ def __repr__(self):
554559 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
555560
556561 def build (self , inputs_shape = None ):
557- self ._strides = [1 , self .strides [0 ], self .strides [1 ], self .strides [2 ], 1 ]
558562 if self .data_format == 'channels_last' :
559563 self .data_format = 'NDHWC'
564+ self ._strides = [1 , self .strides [0 ], self .strides [1 ], self .strides [2 ], 1 ]
560565 elif self .data_format == 'channels_first' :
561566 self .data_format = 'NCDHW'
567+ self ._strides = [1 , 1 , self .strides [0 ], self .strides [1 ], self .strides [2 ]]
562568 else :
563569 raise Exception ("unsupported data format" )
564570
0 commit comments