Skip to content

Commit 2b15f75

Browse files
authored
Add maskconv (#16)
1 parent 23543f2 commit 2b15f75

File tree

4 files changed

+171
-2
lines changed

4 files changed

+171
-2
lines changed

tests/test_mask_conv.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import os
5+
os.environ["TL_BACKEND"] = 'tensorflow'
6+
import tensorlayerx as tlx
7+
from tensorlayerx.nn import Module
8+
from tensorlayerx.nn import MaskedConv3d
9+
from tlx2onnx.main import export
10+
import onnxruntime as rt
11+
import numpy as np
12+
13+
14+
class MLP(Module):
15+
def __init__(self):
16+
super(MLP, self).__init__()
17+
self.mask_conv = MaskedConv3d(mask_type='B', out_channels=32, kernel_size=(1, 1, 1), stride=(2, 2, 2), act=tlx.ReLU, name='conv3d_2',
18+
in_channels=3, padding='SAME')
19+
20+
def forward(self, x):
21+
x = self.mask_conv(x)
22+
return x
23+
24+
net = MLP()
25+
input = tlx.nn.Input(shape=(5, 10, 10, 10, 3))
26+
net.set_eval()
27+
output = net(input)
28+
print("tlx out", output)
29+
onnx_model = export(net, input_spec=input, path='maskconv.onnx')
30+
31+
# Infer Model
32+
sess = rt.InferenceSession('maskconv.onnx')
33+
34+
input_name = sess.get_inputs()[0].name
35+
output_name = sess.get_outputs()[0].name
36+
37+
input_data = np.array(input, dtype=np.float32)
38+
39+
result = sess.run([output_name], {input_name: input_data})
40+
print('onnx out', result)

tlx2onnx/op_mapper/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
from .scale import *
2323
from .stack import *
2424
from .subpixelconv import *
25+
from .mask_conv import *
2526

tlx2onnx/op_mapper/nn/deconv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ def version_1(cls, node, **kwargs):
3535

3636
if data_format == 'channels_last':
3737
# channels last conver weights and input
38-
x_shape = make_shape_channels_first(x_shape)
38+
x_temp_shape = make_shape_channels_first(x_shape)
3939
out_temp_shape = make_shape_channels_first(out_shape)
4040
weights = convert_w(weights_value, data_format, spatial, y)
4141
onnx_init.append(weights)
42-
t_x = helper.make_tensor_value_info(node['in_nodes_name'][0] + 't', NP_TYPE_TO_TENSOR_TYPE[node['dtype']], shape=x_shape)
42+
t_x = helper.make_tensor_value_info(node['in_nodes_name'][0] + 't', NP_TYPE_TO_TENSOR_TYPE[node['dtype']], shape=x_temp_shape)
4343
onnx_value.append(t_x)
4444
tx_node, x = make_node('Transpose', inputs=[x], outputs=[node['in_nodes_name'][0] + 't'], perm=get_channels_first_permutation(spatial))
4545
onnx_node.append(tx_node)

tlx2onnx/op_mapper/nn/mask_conv.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
from onnx import helper, numpy_helper
5+
from ..op_mapper import OpMapper
6+
from ...common import make_node, to_numpy
7+
from ..datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
8+
from ...common import tlx_act_2_onnx, convert_padding, make_shape_channels_first, convert_w, \
9+
get_channels_last_permutation, get_channels_first_permutation
10+
11+
@OpMapper(['MaskedConv3d'])
12+
class MaskedConv3d():
13+
# supports v1-v12
14+
15+
@classmethod
16+
def version_1(cls, node, **kwargs):
17+
onnx_node = []
18+
onnx_value = []
19+
onnx_init = []
20+
21+
x = node['in_nodes_name'][0]
22+
x_shape = node['in_tensors'][0]
23+
out_shape = node['out_tensors'][0]
24+
spatial = int(node['node'].layer.__class__.__name__[-2])
25+
26+
# make weights
27+
y = node['node'].layer.name + '/kernel'
28+
weights_value = node['node'].layer.masked_kernel
29+
30+
attr_dict = {}
31+
attr_dict['dilations'] = dilations = node['attr']['dilation']
32+
attr_dict['kernel_shape'] = kernel_shape = node['attr']['kernel_size']
33+
attr_dict['strides'] = strides = node['attr']['stride']
34+
pads = node['attr']['padding']
35+
data_format = node['attr']['data_format']
36+
37+
if data_format == 'channels_last':
38+
# channels last conver weights and input
39+
x_shape_temp = make_shape_channels_first(x_shape)
40+
out_temp_shape = make_shape_channels_first(out_shape)
41+
weights = convert_w(weights_value, data_format, spatial, y)
42+
onnx_init.append(weights)
43+
t_x = helper.make_tensor_value_info(node['in_nodes_name'][0] + 't', NP_TYPE_TO_TENSOR_TYPE[node['dtype']], shape=x_shape_temp)
44+
onnx_value.append(t_x)
45+
tx_node, x = make_node('Transpose', inputs=[x], outputs=[node['in_nodes_name'][0] + 't'], perm=get_channels_first_permutation(spatial))
46+
onnx_node.append(tx_node)
47+
else:
48+
# Build weights
49+
weights = numpy_helper.from_array(arr=to_numpy(weights_value), name=y)
50+
onnx_init.append(weights)
51+
52+
# Build padding
53+
pads = convert_padding(
54+
pads, x_shape, out_shape, kernel_shape, strides,
55+
dilations, spatial, data_format
56+
)
57+
if isinstance(pads, str):
58+
attr_dict["auto_pad"] = pads
59+
else:
60+
attr_dict["pads"] = pads
61+
62+
if node['node'].layer.b_init is not None:
63+
b = numpy_helper.from_array(arr=to_numpy(node['node'].layer.bias), name=node['node'].layer.name + '/b')
64+
onnx_init.append(b)
65+
b_name = node['node'].layer.name + '/b'
66+
input_list = [x, y, b_name]
67+
else:
68+
input_list = [x, y]
69+
70+
if data_format == 'channels_first':
71+
if node['node'].layer.act is not None:
72+
# Build Conv3d
73+
de_v = helper.make_tensor_value_info(node['out_nodes_name'][0] + 'de', NP_TYPE_TO_TENSOR_TYPE[node['dtype']],
74+
shape=out_shape)
75+
onnx_value.append(de_v)
76+
ct_node, out = make_node('Conv', inputs=input_list,
77+
outputs=[node['out_nodes_name'][0] + 'de'], **attr_dict)
78+
onnx_node.append(ct_node)
79+
80+
act_op = node['node'].layer.act.__class__.__name__
81+
out_v = helper.make_tensor_value_info(node['out_nodes_name'][0], NP_TYPE_TO_TENSOR_TYPE[node['dtype']],
82+
shape=out_shape)
83+
onnx_value.append(out_v)
84+
# Using Opmapper
85+
act_node, _ = tlx_act_2_onnx[act_op]([out], node['out_nodes_name'], node['node'].layer.act)
86+
onnx_node.append(act_node)
87+
else:
88+
out_v = helper.make_tensor_value_info(node['out_nodes_name'][0], NP_TYPE_TO_TENSOR_TYPE[node['dtype']],
89+
shape=out_shape) #
90+
onnx_value.append(out_v)
91+
ct_node, out = make_node('Conv', inputs=input_list,
92+
outputs=node['out_nodes_name'], **attr_dict)
93+
onnx_node.append(ct_node)
94+
elif data_format == 'channels_last':
95+
if node['node'].layer.act is not None:
96+
# Build Conv
97+
ct_v = helper.make_tensor_value_info(node['out_nodes_name'][0] + 'ct', NP_TYPE_TO_TENSOR_TYPE[node['dtype']],
98+
shape=out_temp_shape)
99+
onnx_value.append(ct_v)
100+
ct_node, out = make_node('Conv', inputs=input_list,
101+
outputs=[node['out_nodes_name'][0] + 'ct'], **attr_dict)
102+
onnx_node.append(ct_node)
103+
104+
act_op = node['node'].layer.act.__class__.__name__
105+
act_v = helper.make_tensor_value_info(node['out_nodes_name'][0] + 'a', NP_TYPE_TO_TENSOR_TYPE[node['dtype']],
106+
shape=out_temp_shape)
107+
onnx_value.append(act_v)
108+
# Using Opmapper
109+
act_node, out = tlx_act_2_onnx[act_op]([out], [node['out_nodes_name'][0] + 'a'], node['node'].layer.act)
110+
onnx_node.append(act_node)
111+
else:
112+
out_v = helper.make_tensor_value_info(node['out_nodes_name'][0] + 'ct', NP_TYPE_TO_TENSOR_TYPE[node['dtype']],
113+
shape=out_temp_shape)
114+
onnx_value.append(out_v)
115+
o_node, out = make_node('Conv', inputs=input_list,
116+
outputs=[node['out_nodes_name'][0] + 'ct'], **attr_dict)
117+
onnx_node.append(o_node)
118+
119+
t_out = helper.make_tensor_value_info(node['out_nodes_name'][0], NP_TYPE_TO_TENSOR_TYPE[node['dtype']], shape=out_shape)
120+
onnx_value.append(t_out)
121+
tout_node, _ = make_node('Transpose', inputs=[out], outputs=node['out_nodes_name'], perm=get_channels_last_permutation(spatial))
122+
onnx_node.append(tout_node)
123+
else:
124+
raise ValueError("Only support 'channels_first' or 'channels_last' data_format mode, but got {}.".format(data_format))
125+
126+
return onnx_node, onnx_value, onnx_init
127+
128+

0 commit comments

Comments
 (0)