@@ -84,8 +84,7 @@ class ZeroPad1d(Module):
8484
8585 Parameters
8686 ----------
87- padding : int, or tuple of 2 ints
88- - If int, zeros to add at the beginning and end of the padding dimension (axis 1).
87+ padding : tuple of 2 ints
8988 - If tuple of 2 ints, zeros to add at the beginning and at the end of the padding dimension.
9089 name : None or str
9190 A unique layer name.
@@ -104,10 +103,12 @@ class ZeroPad1d(Module):
104103 def __init__ (
105104 self ,
106105 padding ,
107- name = None , # 'zeropad1d',
106+ name = None ,
107+ data_format = 'channels_last' ,
108108 ):
109109 super ().__init__ (name )
110110 self .padding = padding
111+ self .data_format = data_format
111112 logging .info ("ZeroPad1d %s: padding: %s" % (self .name , str (padding )))
112113
113114 if not isinstance (self .padding , (int , tuple , dict )):
@@ -124,7 +125,7 @@ def __repr__(self):
124125 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
125126
126127 def build (self , inputs_shape = None ):
127- self .layer = tlx .ops .ZeroPadding1D (padding = self .padding )
128+ self .layer = tlx .ops .ZeroPadding1D (padding = self .padding , data_format = self . data_format )
128129
129130 def forward (self , inputs ):
130131 outputs = self .layer (inputs )
@@ -141,9 +142,7 @@ class ZeroPad2d(Module):
141142
142143 Parameters
143144 ----------
144- padding : tuple of 2 ints or int, or tuple of 2 tuples of 2 ints.
145- - If int, the same symmetric padding is applied to width and height.
146- - If tuple of 2 ints, interpreted as two different symmetric padding values for height and width as ``(symmetric_height_pad, symmetric_width_pad)``.
145+ padding : tuple of 2 tuples of 2 ints.
147146 - If tuple of 2 tuples of 2 ints, interpreted as ``((top_pad, bottom_pad), (left_pad, right_pad))``.
148147 name : None or str
149148 A unique layer name.
@@ -162,11 +161,12 @@ class ZeroPad2d(Module):
162161 def __init__ (
163162 self ,
164163 padding ,
165- name = None , # 'zeropad2d',
164+ name = None ,
165+ data_format = 'channels_last' ,
166166 ):
167167 super ().__init__ (name )
168-
169168 self .padding = padding
169+ self .data_format = data_format
170170 logging .info ("ZeroPad2d %s: padding: %s" % (self .name , str (self .padding )))
171171
172172 if not isinstance (self .padding , (int , tuple )):
@@ -183,7 +183,7 @@ def __repr__(self):
183183 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
184184
185185 def build (self , inputs_shape = None ):
186- self .layer = tlx .ops .ZeroPadding2D (padding = self .padding )
186+ self .layer = tlx .ops .ZeroPadding2D (padding = self .padding , data_format = self . data_format )
187187
188188 def forward (self , inputs ):
189189 outputs = self .layer (inputs )
@@ -200,9 +200,7 @@ class ZeroPad3d(Module):
200200
201201 Parameters
202202 ----------
203- padding : int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
204- - If int, the same symmetric padding is applied to width and height.
205- - If tuple of 2 ints, interpreted as two different symmetric padding values for height and width as ``(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)``.
203+ padding : tuple of 2 tuples of 2 ints.
206204 - If tuple of 2 tuples of 2 ints, interpreted as ``((left_dim1_pad, right_dim1_pad), (left_dim2_pad, right_dim2_pad), (left_dim3_pad, right_dim3_pad))``.
207205 name : None or str
208206 A unique layer name.
@@ -221,11 +219,12 @@ class ZeroPad3d(Module):
221219 def __init__ (
222220 self ,
223221 padding ,
224- name = None , # 'zeropad3d',
222+ name = None ,
223+ data_format = 'channels_last' ,
225224 ):
226225 super ().__init__ (name )
227226 self .padding = padding
228-
227+ self . data_format = data_format
229228 logging .info ("ZeroPad3d %s: padding: %s" % (self .name , str (self .padding )))
230229
231230 if not isinstance (self .padding , (int , tuple )):
@@ -242,7 +241,7 @@ def __repr__(self):
242241 return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
243242
244243 def build (self , inputs_shape = None ):
245- self .layer = tlx .ops .ZeroPadding3D (padding = self .padding )
244+ self .layer = tlx .ops .ZeroPadding3D (padding = self .padding , data_format = self . data_format )
246245
247246 def forward (self , inputs ):
248247 outputs = self .layer (inputs )
0 commit comments