Skip to content

Commit 87f6401

Browse files
authored
Add convert 11 (#12)
* Add resampling * Add merge
1 parent e575613 commit 87f6401

File tree

4 files changed

+184
-33
lines changed

4 files changed

+184
-33
lines changed

README.md

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,55 +25,57 @@ pip install tlx2onnx
2525
## Usage
2626
TLX2ONNX can convert models built using TensorLayerX Module Subclass and Layers, and the Layers support list can be found in [Operator list](OP_LIST.md).
2727

28+
The following is an example of converting a multi-layer perceptron. You can get the code from [here](https://github.com/tensorlayer/TLX2ONNX/tree/main/tests/test_merge.py).
2829
```python
29-
#! /usr/bin/python
30-
# -*- coding: utf-8 -*-
31-
30+
import os
31+
os.environ["TL_BACKEND"] = 'tensorflow'
3232
import tensorlayerx as tlx
3333
from tensorlayerx.nn import Module
34-
from tensorlayerx.nn import Linear, Dropout, Flatten, ReLU6
35-
from tlx2onnx import export
34+
from tensorlayerx.nn import Linear, Concat, Elementwise
35+
from tlx2onnx.main import export
3636
import onnxruntime as rt
3737
import numpy as np
3838

39-
class MLP(Module):
39+
class CustomModel(Module):
4040
def __init__(self):
41-
super(MLP, self).__init__()
42-
# weights init
43-
self.flatten = Flatten()
44-
self.line1 = Linear(in_features=32, out_features=64, act=tlx.nn.LeakyReLU(0.3))
45-
self.d1 = Dropout()
46-
self.line2 = Linear(in_features=64, out_features=128, b_init=None, act=tlx.nn.ReLU)
47-
self.relu6 = ReLU6()
48-
self.line3 = Linear(in_features=128, out_features=10, act=tlx.nn.ReLU)
49-
50-
def forward(self, x):
51-
x = self.flatten(x)
52-
z = self.line1(x)
53-
z = self.d1(z)
54-
z = self.line2(z)
55-
z = self.relu6(z)
56-
z = self.line3(z)
57-
return z
58-
59-
net = MLP()
60-
net.eval()
61-
input = tlx.nn.Input(shape=(3, 2, 2, 8))
62-
onnx_model = export(net, input_spec=input, path='linear_model.onnx')
41+
super(CustomModel, self).__init__(name="custom")
42+
self.linear1 = Linear(in_features=20, out_features=10, act=tlx.ReLU, name='relu1_1')
43+
self.linear2 = Linear(in_features=20, out_features=10, act=tlx.ReLU, name='relu2_1')
44+
self.concat = Concat(concat_dim=1, name='concat_layer')
45+
46+
def forward(self, inputs):
47+
d1 = self.linear1(inputs)
48+
d2 = self.linear2(inputs)
49+
outputs = self.concat([d1, d2])
50+
return outputs
51+
52+
net = CustomModel()
53+
input = tlx.nn.Input(shape=(3, 20), init=tlx.initializers.RandomNormal())
54+
net.set_eval()
55+
output = net(input)
56+
print("tlx out", output)
57+
onnx_model = export(net, input_spec=input, path='concat.onnx')
6358

6459
# Infer Model
65-
sess = rt.InferenceSession('linear_model.onnx')
60+
sess = rt.InferenceSession('concat.onnx')
6661
input_name = sess.get_inputs()[0].name
6762
output_name = sess.get_outputs()[0].name
68-
input_data = tlx.nn.Input(shape=(3, 2, 2, 8))
69-
input_data = np.array(input_data, dtype=np.float32)
63+
input_data = np.array(input, dtype=np.float32)
7064
result = sess.run([output_name], {input_name: input_data})
71-
print(result)
65+
print('onnx out', result)
7266
```
67+
The converted onnx file can be viewed via Netron.
68+
69+
<p align="center"><img src="https://git.openi.org.cn/laich/pose_data/raw/commit/7ac74f03dbfdd8e023cdb205cd415a8571ebb91a/onnxfile.png" width="580"\></p>
70+
71+
72+
The converted results have almost no loss of accuracy.
73+
And the graph show the input and output sizes of each layer, which is very helpful for checking the model.
74+
7375

7476
# Citation
7577

76-
If you find TensorLayerX useful for your project, please cite the following papers:
78+
If you find TensorLayerX or TLX2ONNX useful for your project, please cite the following papers:
7779

7880
```
7981
@article{tensorlayer2017,

tests/test_merge.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import os
5+
os.environ["TL_BACKEND"] = 'tensorflow'
6+
import tensorlayerx as tlx
7+
from tensorlayerx.nn import Module
8+
from tensorlayerx.nn import Linear, Concat, Elementwise
9+
from tlx2onnx.main import export
10+
import onnxruntime as rt
11+
import numpy as np
12+
13+
class CustomModel(Module):
14+
def __init__(self):
15+
super(CustomModel, self).__init__(name="custom")
16+
self.linear1 = Linear(in_features=20, out_features=10, act=tlx.ReLU, name='relu1_1')
17+
self.linear2 = Linear(in_features=20, out_features=10, act=tlx.ReLU, name='relu2_1')
18+
self.concat = Concat(concat_dim=1, name='concat_layer')
19+
20+
def forward(self, inputs):
21+
d1 = self.linear1(inputs)
22+
d2 = self.linear2(inputs)
23+
outputs = self.concat([d1, d2])
24+
return outputs
25+
26+
net = CustomModel()
27+
input = tlx.nn.Input(shape=(3, 20), init=tlx.initializers.RandomNormal())
28+
net.set_eval()
29+
output = net(input)
30+
print("tlx out", output)
31+
onnx_model = export(net, input_spec=input, path='concat.onnx')
32+
33+
# Infer Model
34+
sess = rt.InferenceSession('concat.onnx')
35+
36+
input_name = sess.get_inputs()[0].name
37+
output_name = sess.get_outputs()[0].name
38+
39+
input_data = np.array(input, dtype=np.float32)
40+
41+
result = sess.run([output_name], {input_name: input_data})
42+
print('onnx out', result)
43+
44+
##################################### Elementwise ###################################################
45+
class CustomModel2(Module):
46+
def __init__(self):
47+
super(CustomModel2, self).__init__(name="custom")
48+
self.linear1 = Linear(in_features=10, out_features=10, act=tlx.ReLU, name='relu1_1')
49+
self.linear2 = Linear(in_features=10, out_features=10, act=tlx.ReLU, name='relu2_1')
50+
self.linear3 = Linear(in_features=10, out_features=10, act=tlx.ReLU, name='relu3_1')
51+
self.element = Elementwise(combine_fn=tlx.matmul, name='concat')
52+
53+
def forward(self, inputs):
54+
d1 = self.linear1(inputs)
55+
d2 = self.linear2(inputs)
56+
d3 = self.linear3(inputs)
57+
outputs = self.element([d1, d2, d3])
58+
return outputs
59+
60+
net = CustomModel2()
61+
input = tlx.nn.Input(shape=(10, 10), init=tlx.initializers.RandomNormal())
62+
net.set_eval()
63+
output = net(input)
64+
print("tlx out", output)
65+
onnx_model2 = export(net, input_spec=input, path='elementwise.onnx')
66+
67+
# Infer Model
68+
sess = rt.InferenceSession('elementwise.onnx')
69+
70+
input_name = sess.get_inputs()[0].name
71+
output_name = sess.get_outputs()[0].name
72+
73+
input_data = np.array(input, dtype=np.float32)
74+
75+
result = sess.run([output_name], {input_name: input_data})
76+
print('onnx out', result)

tlx2onnx/op_mapper/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@
1616
from .extend import *
1717
from .rnn import *
1818
from .resampling import *
19+
from .merge import *

tlx2onnx/op_mapper/nn/merge.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
from onnx import helper
5+
from ..op_mapper import OpMapper
6+
from ...common import make_node
7+
from ..datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
8+
import numpy as np
9+
10+
@OpMapper(['Concat'])
11+
class Concat():
12+
# supports v1-v12
13+
14+
@classmethod
15+
def version_1(cls, node, **kwargs):
16+
onnx_node = []
17+
onnx_value = []
18+
onnx_init = []
19+
# get inputs outputs
20+
in_name = node['in_nodes_name']
21+
out_name = node['out_nodes_name'][0]
22+
out_shape = node['out_tensors'][0]
23+
dtype = NP_TYPE_TO_TENSOR_TYPE[node['dtype']]
24+
layer = node['node'].layer
25+
concat_dim = layer.concat_dim
26+
# make concat node
27+
out_v = helper.make_tensor_value_info(out_name, dtype, shape=out_shape)
28+
onnx_value.append(out_v)
29+
out_node, _ = make_node('Concat', inputs=[s for s in in_name], outputs=node['out_nodes_name'], axis=concat_dim)
30+
onnx_node.append(out_node)
31+
return onnx_node, onnx_value, onnx_init
32+
33+
34+
@OpMapper(['Elementwise'])
35+
class Elementwise():
36+
# supports v1-v12
37+
38+
@classmethod
39+
def version_1(cls, node, **kwargs):
40+
onnx_node = []
41+
onnx_value = []
42+
onnx_init = []
43+
# get inputs outputs
44+
in_name = node['in_nodes_name']
45+
out_name = node['out_nodes_name'][0]
46+
out_shape = node['out_tensors'][0]
47+
dtype = NP_TYPE_TO_TENSOR_TYPE[node['dtype']]
48+
layer = node['node'].layer
49+
combine_fn_name = cls.fn_dict(str(layer.combine_fn.__name__))
50+
print(combine_fn_name)
51+
# make combine_fn node
52+
out_v = helper.make_tensor_value_info(out_name, dtype, shape=out_shape)
53+
onnx_value.append(out_v)
54+
55+
out = in_name[0]
56+
for i in np.arange(1, len(in_name)):
57+
if i == len(in_name) - 1:
58+
out_node, out = make_node(combine_fn_name, inputs=[out, in_name[i]], outputs=[out_name])
59+
onnx_node.append(out_node)
60+
else:
61+
out_node, out = make_node(combine_fn_name, inputs=[out, in_name[i]], outputs=[out_name + str(i)])
62+
onnx_node.append(out_node)
63+
return onnx_node, onnx_value, onnx_init
64+
65+
@staticmethod
66+
def fn_dict(fn):
67+
# More operator operations can be added from here.
68+
_dict = {
69+
'matmul': 'MatMul',
70+
'add': 'Add',
71+
}
72+
return _dict[fn]

0 commit comments

Comments
 (0)