Skip to content

Commit 2b08d24

Browse files
committed
add groupconv2d
1 parent 2b15f75 commit 2b08d24

File tree

3 files changed

+305
-0
lines changed

3 files changed

+305
-0
lines changed

tests/test_groupconv2d.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import os
5+
# os.environ["TL_BACKEND"] = 'tensorflow'
6+
# os.environ["TL_BACKEND"] = 'paddle'
7+
os.environ["TL_BACKEND"] = 'torch'
8+
# os.environ["TL_BACKEND"] = 'mindspore'
9+
import tensorlayerx as tlx
10+
from tensorlayerx.nn import Module
11+
from tensorlayerx.nn import GroupConv2d
12+
from tlx2onnx.main import export
13+
import onnxruntime as rt
14+
import numpy as np
15+
16+
17+
############################################ test 2d ###########################################################
18+
class CNN(Module):
19+
20+
def __init__(self):
21+
super(CNN, self).__init__()
22+
# weights init
23+
W_init = tlx.nn.initializers.truncated_normal(stddev=5e-2)
24+
b_init2 = tlx.nn.initializers.constant(value=0.1)
25+
self.conv1 = GroupConv2d(
26+
36, (5, 5), (1, 1), n_group=3, padding=(2,2), W_init=W_init, b_init=b_init2, name='conv1',
27+
in_channels=3, data_format='channels_last', act = tlx.nn.ReLU
28+
)
29+
def forward(self, x):
30+
z = self.conv1(x)
31+
return z
32+
33+
net = CNN()
34+
input = tlx.nn.Input(shape=(1, 10, 10, 3))
35+
net.set_eval()
36+
output = net(input)
37+
print("groupconv2d tlx output", output)
38+
onnx_model = export(net, input_spec=input, path='groupconv2d_model.onnx')
39+
40+
# Infer Model
41+
sess = rt.InferenceSession('groupconv2d_model.onnx')
42+
43+
input_name = sess.get_inputs()[0].name
44+
output_name = sess.get_outputs()[0].name
45+
46+
input_data = np.array(input, dtype=np.float32)
47+
48+
result = sess.run([output_name], {input_name: input_data})
49+
print("groupconv2d onnx output", result)

tlx2onnx/op_mapper/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@
2323
from .stack import *
2424
from .subpixelconv import *
2525
from .mask_conv import *
26+
from .group_conv import *
2627

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
import numpy as np
4+
from onnx import helper, numpy_helper
5+
from collections import OrderedDict
6+
import tensorlayerx as tlx
7+
from tlx2onnx.op_mapper.datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
8+
from tlx2onnx.op_mapper.op_mapper import OpMapper
9+
from tlx2onnx.common import make_node
10+
from tlx2onnx.common import make_shape_channels_first, get_channels_first_permutation,tlx_act_2_onnx,get_channels_last_permutation
11+
12+
13+
def convert_padding(padding, input_shape, output_shape, kernel_shape, strides, dilations, spatial, data_format):
14+
if isinstance(padding, str):
15+
if padding == "SAME":
16+
pads = [0] * (spatial * 2)
17+
if data_format == "channels_last":
18+
input_shape = make_shape_channels_first(input_shape)
19+
output_shape = make_shape_channels_first(output_shape)
20+
21+
if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)):
22+
23+
auto_pad = "SAME_UPPER"
24+
25+
return auto_pad
26+
27+
for i in range(spatial):
28+
pad = (
29+
(output_shape[i + 2] - 1) * strides[i]
30+
+ dilations[i] * (kernel_shape[i] - 1) + 1
31+
- input_shape[i + 2]
32+
)
33+
pad = max(pad, 0)
34+
pads[i] = pad // 2
35+
pads[i + spatial] = pad - pad // 2
36+
37+
return pads
38+
39+
elif padding == "VALID":
40+
auto_pad = "VALID"
41+
return auto_pad
42+
elif isinstance(padding, int):
43+
pads = [padding] * spatial * 2
44+
return pads
45+
elif isinstance(padding, tuple):
46+
return list(padding) * 2
47+
48+
def convert_w(w, data_format, spatial, w_name):
49+
w = tlx.convert_to_numpy(w)
50+
if tlx.BACKEND == 'tensorflow':
51+
if spatial == 2:
52+
w = np.transpose(w, axes=[3, 2, 0, 1])
53+
elif spatial == 1:
54+
w = np.transpose(w, axes=[2, 1, 0])
55+
elif spatial == 3:
56+
w = np.transpose(w, axes=[4, 3, 0, 1, 2])
57+
return numpy_helper.from_array(w, name=w_name)
58+
elif tlx.BACKEND == 'mindspore':
59+
if spatial == 2 and data_format == 'channels_last':
60+
w = np.transpose(w, axes=[3, 0, 1, 2])
61+
return numpy_helper.from_array(w, name=w_name)
62+
return numpy_helper.from_array(w, name=w_name)
63+
64+
def convert_b(b, b_name):
65+
b = tlx.convert_to_numpy(b)
66+
return numpy_helper.from_array(b, name=b_name)
67+
68+
69+
@OpMapper(["GroupConv2d"])
70+
class Conv():
71+
# suppport v1-v13
72+
73+
@classmethod
74+
def any_version(cls, node, opset, **kwargs):
75+
"""
76+
Parameters
77+
----------
78+
node:node dict {node: node,
79+
in_tensors: node inputs,
80+
out_tensors: node outputs,
81+
in_nodes_name: node inputs name,
82+
out_nodes_name: node outputs name}
83+
Returns
84+
-------
85+
"""
86+
Op_name = 'Conv'
87+
onnx_node, onnx_value, onnx_init = [], [], []
88+
attr_dict = OrderedDict()
89+
90+
#### get data_type
91+
data_type = node['dtype']
92+
tensor_type = NP_TYPE_TO_TENSOR_TYPE[data_type]
93+
#### get in_node_name out_node_nmae
94+
x_name = node['in_nodes_name'][0]
95+
out_name = node['out_nodes_name'][0]
96+
x_shape = node['in_tensors'][0]
97+
out_shape = node['out_tensors'][0]
98+
99+
#### get cur_node_layer node_index
100+
layer = node['node'].layer
101+
layer_type = layer.__class__.__name__
102+
spatial = int(layer_type[-2])
103+
node_name = layer.name
104+
#### get layer_param
105+
layer_param = layer.all_weights
106+
107+
#### get layer_act_type
108+
layer_act = layer.act.__class__.__name__
109+
110+
#### conv inputs
111+
w = None
112+
b = None
113+
if len(layer_param) == 1:
114+
w = layer_param[0]
115+
elif len(layer_param) == 2:
116+
w = layer_param[0]
117+
b = layer_param[1]
118+
119+
#### insert conv attr
120+
kernel_size = node['attr']['kernel_size']
121+
if isinstance(kernel_size, int):
122+
kernel_size = [kernel_size]
123+
attr_dict["kernel_shape"] = kernel_size
124+
dilations = node['attr']['dilation']
125+
if isinstance(dilations, int):
126+
dilations = [dilations,]
127+
attr_dict["dilations"] = dilations
128+
strides = node['attr']['stride']
129+
if isinstance(strides, int):
130+
strides = [strides]
131+
attr_dict["strides"] = strides
132+
data_format = node['attr']['data_format']
133+
paddding = node['attr']['padding']
134+
attr_dict["group"] = layer.n_group
135+
attr_dict["outputs"] = [out_name]
136+
137+
####convert padding
138+
pads = convert_padding(
139+
paddding, x_shape, out_shape, attr_dict["kernel_shape"], attr_dict["strides"],
140+
attr_dict["dilations"], spatial, data_format
141+
)
142+
if isinstance(pads, str):
143+
attr_dict["auto_pad"] = pads
144+
else:
145+
attr_dict["pads"] = pads
146+
147+
if data_format == 'channels_last':
148+
permutation = get_channels_first_permutation(spatial)
149+
x_shape_t = make_shape_channels_first(x_shape)
150+
# insert transpose op: NHWC -> NCHW
151+
transpose_value = helper.make_tensor_value_info(x_name+'_t', tensor_type, shape=x_shape_t)
152+
onnx_value.append(transpose_value)
153+
transpose_node, out = make_node('Transpose', inputs=[x_name], outputs=[x_name+'_t'], perm = permutation)
154+
onnx_node.append(transpose_node)
155+
# convert w
156+
w_name = node_name + '_w'
157+
w_init = convert_w(w, data_format, spatial, w_name)
158+
onnx_init.append(w_init)
159+
attr_dict["inputs"] = [out, w_name]
160+
161+
#### convert b
162+
if b is not None:
163+
b_name = node_name + '_b'
164+
b_init = convert_b(b, b_name)
165+
onnx_init.append(b_init)
166+
attr_dict["inputs"] = [out, w_name, b_name]
167+
168+
attr_dict["outputs"] = [out + "_t"]
169+
conv_node, out = make_node(Op_name, **attr_dict)
170+
onnx_node.append(conv_node)
171+
out_shape_t = make_shape_channels_first(out_shape)
172+
conv_value = helper.make_tensor_value_info(out, tensor_type, shape=out_shape_t)
173+
onnx_value.append(conv_value)
174+
# insert transpose op: NCHW -> NHWC and insert act node
175+
176+
if layer_act != 'NoneType':
177+
act_convert = tlx_act_2_onnx[layer_act]
178+
act_input = out_name + "_act"
179+
act_out = out_name
180+
# insert transpose op
181+
permutation = get_channels_last_permutation(spatial)
182+
transpose_node, out = make_node('Transpose', inputs=[out], outputs=[act_input], perm = permutation)
183+
onnx_node.append(transpose_node)
184+
transpose_value = helper.make_tensor_value_info(act_input, tensor_type, shape = out_shape)
185+
onnx_value.append(transpose_value)
186+
# 如果layer存在act,需要新增一个act node 和 对应act输入的act input info, 并且要更新 conv的outputs 为 act的inputs, 此时act的outputs是整个layer的outputs
187+
act_node, _ = act_convert([out], [act_out])
188+
act_input_value_info = helper.make_tensor_value_info(act_out, tensor_type, out_shape)
189+
onnx_value.append(act_input_value_info)
190+
onnx_node.append(act_node)
191+
return onnx_node, onnx_value, onnx_init
192+
else:
193+
permutation = get_channels_last_permutation(spatial)
194+
transpose_node, out = make_node('Transpose', inputs=[out], outputs=[out_name], perm=permutation)
195+
onnx_node.append(transpose_node)
196+
transpose_value = helper.make_tensor_value_info(out_name, tensor_type, shape=out_shape)
197+
onnx_value.append(transpose_value)
198+
return onnx_node, onnx_value, onnx_init
199+
200+
201+
elif data_format == 'channels_first':
202+
203+
#### convert w
204+
w_name = node_name + '_w'
205+
w_init = convert_w(w, data_format, spatial, w_name)
206+
onnx_init.append(w_init)
207+
attr_dict["inputs"] = [x_name, w_name]
208+
209+
#### convert b
210+
if b is not None:
211+
b_name = node_name + '_b'
212+
b_init = convert_b(b, b_name)
213+
onnx_init.append(b_init)
214+
attr_dict["inputs"] = [x_name, w_name, b_name]
215+
216+
#### make act node
217+
if layer_act != 'NoneType':
218+
act_convert = tlx_act_2_onnx[layer_act]
219+
act_input = out_name + "_act"
220+
act_out = out_name
221+
attr_dict["outputs"] = [act_input]
222+
conv_node, out = make_node(Op_name, **attr_dict)
223+
onnx_node.append(conv_node)
224+
conv_value = helper.make_tensor_value_info(out, tensor_type, shape = out_shape)
225+
onnx_value.append(conv_value)
226+
#insert act node
227+
act_node, out = act_convert([act_input], [act_out])
228+
act_input_value_info = helper.make_tensor_value_info(out, tensor_type, out_shape)
229+
onnx_value.append(act_input_value_info)
230+
onnx_node.append(act_node)
231+
return onnx_node, onnx_value, onnx_init
232+
else:
233+
conv_node, out = make_node(Op_name, **attr_dict)
234+
onnx_node.append(conv_node)
235+
conv_value = helper.make_tensor_value_info(out, tensor_type, out_shape)
236+
onnx_value.append(conv_value)
237+
return onnx_node, onnx_value, onnx_init
238+
else:
239+
raise ValueError("Only support 'channels_first' or 'channels_last' data_format mode, but got {}.".format(data_format))
240+
241+
@classmethod
242+
def version_1(cls, node, **kwargs):
243+
244+
return cls.any_version(node, 1, **kwargs)
245+
246+
247+
@classmethod
248+
def version_11(cls, node, **kwargs):
249+
250+
return cls.any_version( node, 11, **kwargs)
251+
252+
@classmethod
253+
def version_13(cls, node, **kwargs):
254+
255+
return cls.any_version(node, 13, **kwargs)

0 commit comments

Comments
 (0)