Skip to content

Commit d35bb36

Browse files
authored
Add zeropad layers (#13)
1 parent 60dad1f commit d35bb36

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

tests/test_padding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ class CustomModel(Module):
1414
def __init__(self):
1515
super(CustomModel, self).__init__(name="custom")
1616
self.pad = PadLayer([[1, 2], [3, 4], [5, 6], [7, 8]], "REFLECT", name='inpad')
17+
self.pad2d = ZeroPad2d(padding=((2, 2), (3, 3)), data_format='channels_last')
1718

1819
def forward(self, inputs):
1920
x = self.pad(inputs)
21+
x = self.pad2d(x)
2022
return x
2123

2224
net = CustomModel()

tlx2onnx/op_mapper/nn/padding.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def version_1(cls, node, **kwargs):
2222
layer = node['node'].layer
2323
# get attrs
2424
value = np.array(layer.constant_values).astype(node['dtype'])
25-
c_value = numpy_helper.from_array(value, name='value')
25+
c_value = numpy_helper.from_array(value, name=layer.name + 'value')
2626
onnx_init.append(c_value)
2727
# processing mode
2828
mode_dict = {"CONSTANT": 'constant', "REFLECT": 'reflect',"SYMMETRIC": 'edge'}
@@ -39,17 +39,71 @@ def version_1(cls, node, **kwargs):
3939
for i in range(len(pads_temp) // 2):
4040
pads.append(pads_temp[i*2+1])
4141
pads = np.array(pads).astype(np.int64)
42-
p_value = numpy_helper.from_array(pads, name='pads')
42+
p_value = numpy_helper.from_array(pads, name=layer.name + 'pads')
4343
onnx_init.append(p_value)
4444
# make nodes
4545
v_out = helper.make_tensor_value_info(out_name, dtype, shape=out_shape)
4646
onnx_value.append(v_out)
4747

4848
if mode == 'constant':
49-
p_node, out = make_node('Pad', inputs=[in_name, 'pads', 'value'], outputs=[out_name], mode='constan')
49+
p_node, out = make_node('Pad', inputs=[in_name, layer.name + 'pads', layer.name + 'value'], outputs=[out_name], mode='constant')
5050
onnx_node.append(p_node)
5151
else:
52-
p_node, out = make_node('Pad', inputs=[in_name, 'pads'], outputs=[out_name], mode=mode)
52+
p_node, out = make_node('Pad', inputs=[in_name, layer.name + 'pads'], outputs=[out_name], mode=mode)
5353
onnx_node.append(p_node)
5454

55-
return onnx_node, onnx_value, onnx_init
55+
return onnx_node, onnx_value, onnx_init
56+
57+
58+
@OpMapper(['ZeroPad1d', 'ZeroPad2d', 'ZeroPad3d'])
59+
class ZeroPad():
60+
# supports v1-v12
61+
62+
@classmethod
63+
def version_1(cls, node, **kwargs):
64+
onnx_node, onnx_value, onnx_init = [], [], []
65+
# get inputs outputs
66+
in_name = node['in_nodes_name'][0]
67+
out_name = node['out_nodes_name'][0]
68+
out_shape = node['out_tensors'][0]
69+
dtype = NP_TYPE_TO_TENSOR_TYPE[node['dtype']]
70+
layer = node['node'].layer
71+
# get attrs
72+
padding = layer.padding
73+
data_format = layer.data_format
74+
pads_temp = convert_padding(padding, data_format)
75+
76+
pads = []
77+
for i in range(len(pads_temp)//2):
78+
pads.append(pads_temp[2*i])
79+
for i in range(len(pads_temp) // 2):
80+
pads.append(pads_temp[i*2+1])
81+
pads = np.array(pads).astype(np.int64)
82+
83+
p_value = numpy_helper.from_array(pads, name=layer.name + 'pads')
84+
onnx_init.append(p_value)
85+
86+
# make nodes
87+
v_out = helper.make_tensor_value_info(out_name, dtype, shape=out_shape)
88+
onnx_value.append(v_out)
89+
p_node, out = make_node('Pad', inputs=[in_name, layer.name + 'pads'], outputs=[out_name], mode='constant')
90+
onnx_node.append(p_node)
91+
92+
return onnx_node, onnx_value, onnx_init
93+
94+
95+
def convert_padding(padding, data_format):
96+
if np.size(padding) == 2:
97+
if data_format == 'channels_first':
98+
out = (0, 0, 0, 0) + padding
99+
else:
100+
out = (0, 0) + padding + (0, 0)
101+
else:
102+
pads_temp = padding[0]
103+
for i in np.arange(1, len(padding)):
104+
pads_temp += padding[i]
105+
if data_format == 'channels_first':
106+
out = (0, 0, 0, 0) + pads_temp
107+
else:
108+
out = (0, 0) + pads_temp + (0, 0)
109+
return out

0 commit comments

Comments
 (0)