@@ -431,16 +431,22 @@ class Conv2D(object):
431431
432432 def __init__ (self , strides , padding , data_format = 'NHWC' , dilations = None , out_channel = None , k_size = None ):
433433 self .data_format , self .padding = preprocess_2d_format (data_format , padding )
434- self .ksize = k_size [0 ]
435434 if self .data_format is 'NHWC' :
436- self .dg_stride = strides [1 ]
437- self .dg_dilation = dilations [1 ]
435+ self ._stride = ( strides [1 ], strides [ 2 ])
436+ self ._dilation = ( dilations [1 ], dilations [ 2 ])
438437 elif self .data_format is 'NCHW' :
439- self .dg_stride = strides [2 ]
440- self .dg_dilation = dilations [2 ]
438+ self ._stride = (strides [2 ], strides [3 ])
439+ self ._dilation = (dilations [2 ], dilations [3 ])
440+
441441
442442 def __call__ (self , inputs , filters ):
443- raise NotImplementedError
443+ outputs = F .conv2d (x = inputs ,
444+ weight = filters ,
445+ stride = self ._stride ,
446+ dilation = self ._dilation ,
447+ padding = self .padding ,
448+ data_format = self .data_format )
449+ return outputs
444450
445451
446452def conv2d (input , filters , strides , padding , data_format = 'NCHW' , dilations = None ):
@@ -468,7 +474,20 @@ def conv2d(input, filters, strides, padding, data_format='NCHW', dilations=None)
468474 -------
469475 A Tensor. Has the same type as input.
470476 """
471- raise NotImplementedError
477+ data_format , padding = preprocess_2d_format (data_format , padding )
478+ if data_format is 'NHWC' :
479+ _stride = (strides [1 ], strides [2 ])
480+ _dilation = (dilations [1 ], dilations [2 ])
481+ elif data_format is 'NCHW' :
482+ _stride = (strides [2 ], strides [3 ])
483+ _dilation = (dilations [2 ], dilations [3 ])
484+ outputs = F .conv2d (x = input ,
485+ weight = filters ,
486+ stride = _stride ,
487+ dilation = _dilation ,
488+ padding = padding ,
489+ data_format = data_format )
490+ return outputs
472491
473492
474493class Conv3D (object ):
@@ -577,10 +596,18 @@ class MaxPool(object):
577596 def __init__ (self , ksize , strides , padding , data_format = None ):
578597 self .data_format , self .padding = preprocess_2d_format (data_format , padding )
579598 self .ksize = ksize
580- self .strides = strides
599+ if self .data_format is 'NHWC' :
600+ self ._stride = (strides [1 ], strides [2 ])
601+ elif self .data_format is 'NCHW' :
602+ self ._stride = (strides [2 ], strides [3 ])
581603
582604 def __call__ (self , inputs ):
583- raise NotImplementedError
605+ outputs = F .max_pool2d (x = inputs ,
606+ kernel_size = self .ksize ,
607+ stride = self ._stride ,
608+ padding = self .padding ,
609+ data_format = self .data_format )
610+ return outputs
584611
585612
586613def max_pool (input , ksize , strides , padding , data_format = None ):
@@ -951,7 +978,7 @@ def __init__(self):
951978 pass
952979
953980 def __call__ (self , * args , ** kwargs ):
954- raise NotImplementedError
981+ pd . nn . BatchNorm2D
955982
956983
957984class GroupConv2D (object ):
0 commit comments