Skip to content

Commit 472d4fd

Browse files
committed
convert SRGAN model test
1 parent 2b08d24 commit 472d4fd

File tree

7 files changed

+470
-121
lines changed

7 files changed

+470
-121
lines changed

tests/export_application/test_export_srgan.py

Lines changed: 424 additions & 0 deletions
Large diffs are not rendered by default.

tlx2onnx/common/preprocessing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ def convert_padding(padding, input_shape, output_shape, kernel_shape, strides, d
2323
output_shape = make_shape_channels_first(output_shape)
2424

2525
if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)):
26-
return "SAME_UPPER"
26+
27+
auto_pad = "SAME_UPPER"
28+
29+
return auto_pad
2730

2831
for i in range(spatial):
2932
pad = (
@@ -38,14 +41,14 @@ def convert_padding(padding, input_shape, output_shape, kernel_shape, strides, d
3841
return pads
3942

4043
elif padding == "VALID":
41-
return "VALID"
44+
auto_pad = "VALID"
45+
return auto_pad
4246
elif isinstance(padding, int):
4347
pads = [padding] * spatial * 2
4448
return pads
4549
elif isinstance(padding, tuple):
4650
return list(padding) * 2
4751

48-
4952
def convert_w(w, data_format, spatial, w_name):
5053
w = tlx.convert_to_numpy(w)
5154
if tlx.BACKEND == 'tensorflow':
@@ -62,6 +65,9 @@ def convert_w(w, data_format, spatial, w_name):
6265
return numpy_helper.from_array(w, name=w_name)
6366
return numpy_helper.from_array(w, name=w_name)
6467

68+
def convert_b(b, b_name):
69+
b = tlx.convert_to_numpy(b)
70+
return numpy_helper.from_array(b, name=b_name)
6571

6672
def convert_tlx_relu(inputs, outputs, act = None):
6773
opsets = OpMapper.OPSETS['ReLU']
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
33

4-
from . import matmul
4+
from . import matmul
5+
from . import add

tlx2onnx/op_mapper/math/add.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
from onnx import helper, TensorProto
5+
from ..op_mapper import OpMapper
6+
from ...common import make_node, transpose_shape
7+
from tlx2onnx.op_mapper.datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
8+
9+
@OpMapper('Add')
10+
class Add():
11+
# supports v7-v12
12+
13+
@classmethod
14+
def version_7(cls, node, **kwargs):
15+
onnx_node = []
16+
onnx_value = []
17+
onnx_init = []
18+
19+
x_name = node['in_nodes_name'][0]
20+
y_name = node['in_nodes_name'][1]
21+
out_name = node['out_nodes_name'][0]
22+
# x_shape = node['in_tensors'][0]
23+
# y_shape = node['in_tensors'][1]
24+
# out_shape = node['out_tensors'][0]
25+
26+
op_type = 'Add'
27+
add_node, _ = make_node(op_type, inputs=[x_name, y_name], outputs=[out_name])
28+
onnx_node.append(add_node)
29+
return onnx_node, onnx_value, onnx_init
30+

tlx2onnx/op_mapper/nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,5 @@
2323
from .stack import *
2424
from .subpixelconv import *
2525
from .mask_conv import *
26-
from .group_conv import *
26+
from .groupconv import *
2727

tlx2onnx/op_mapper/nn/conv.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,64 +7,8 @@
77
from tlx2onnx.op_mapper.datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
88
from tlx2onnx.op_mapper.op_mapper import OpMapper
99
from tlx2onnx.common import make_node
10-
from tlx2onnx.common import make_shape_channels_first, get_channels_first_permutation,tlx_act_2_onnx,get_channels_last_permutation
11-
12-
13-
def convert_padding(padding, input_shape, output_shape, kernel_shape, strides, dilations, spatial, data_format):
14-
if isinstance(padding, str):
15-
if padding == "SAME":
16-
pads = [0] * (spatial * 2)
17-
if data_format == "channels_last":
18-
input_shape = make_shape_channels_first(input_shape)
19-
output_shape = make_shape_channels_first(output_shape)
20-
21-
if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)):
22-
23-
auto_pad = "SAME_UPPER"
24-
25-
return auto_pad
26-
27-
for i in range(spatial):
28-
pad = (
29-
(output_shape[i + 2] - 1) * strides[i]
30-
+ dilations[i] * (kernel_shape[i] - 1) + 1
31-
- input_shape[i + 2]
32-
)
33-
pad = max(pad, 0)
34-
pads[i] = pad // 2
35-
pads[i + spatial] = pad - pad // 2
36-
37-
return pads
38-
39-
elif padding == "VALID":
40-
auto_pad = "VALID"
41-
return auto_pad
42-
elif isinstance(padding, int):
43-
pads = [padding] * spatial * 2
44-
return pads
45-
elif isinstance(padding, tuple):
46-
return list(padding) * 2
47-
48-
def convert_w(w, data_format, spatial, w_name):
49-
w = tlx.convert_to_numpy(w)
50-
if tlx.BACKEND == 'tensorflow':
51-
if spatial == 2:
52-
w = np.transpose(w, axes=[3, 2, 0, 1])
53-
elif spatial == 1:
54-
w = np.transpose(w, axes=[2, 1, 0])
55-
elif spatial == 3:
56-
w = np.transpose(w, axes=[4, 3, 0, 1, 2])
57-
return numpy_helper.from_array(w, name=w_name)
58-
elif tlx.BACKEND == 'mindspore':
59-
if spatial == 2 and data_format == 'channels_last':
60-
w = np.transpose(w, axes=[3, 0, 1, 2])
61-
return numpy_helper.from_array(w, name=w_name)
62-
return numpy_helper.from_array(w, name=w_name)
63-
64-
def convert_b(b, b_name):
65-
b = tlx.convert_to_numpy(b)
66-
return numpy_helper.from_array(b, name=b_name)
67-
10+
from tlx2onnx.common import make_shape_channels_first, get_channels_first_permutation,get_channels_last_permutation
11+
from tlx2onnx.common import convert_padding, convert_w, tlx_act_2_onnx, convert_b
6812

6913
@OpMapper(["Conv1d", "Conv2d", "Conv3d"])
7014
class Conv():

tlx2onnx/op_mapper/nn/group_conv.py renamed to tlx2onnx/op_mapper/nn/groupconv.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,64 +7,8 @@
77
from tlx2onnx.op_mapper.datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
88
from tlx2onnx.op_mapper.op_mapper import OpMapper
99
from tlx2onnx.common import make_node
10-
from tlx2onnx.common import make_shape_channels_first, get_channels_first_permutation,tlx_act_2_onnx,get_channels_last_permutation
11-
12-
13-
def convert_padding(padding, input_shape, output_shape, kernel_shape, strides, dilations, spatial, data_format):
14-
if isinstance(padding, str):
15-
if padding == "SAME":
16-
pads = [0] * (spatial * 2)
17-
if data_format == "channels_last":
18-
input_shape = make_shape_channels_first(input_shape)
19-
output_shape = make_shape_channels_first(output_shape)
20-
21-
if any(input_shape[i + 2] == -1 or output_shape[i + 2] == -1 for i in range(spatial)):
22-
23-
auto_pad = "SAME_UPPER"
24-
25-
return auto_pad
26-
27-
for i in range(spatial):
28-
pad = (
29-
(output_shape[i + 2] - 1) * strides[i]
30-
+ dilations[i] * (kernel_shape[i] - 1) + 1
31-
- input_shape[i + 2]
32-
)
33-
pad = max(pad, 0)
34-
pads[i] = pad // 2
35-
pads[i + spatial] = pad - pad // 2
36-
37-
return pads
38-
39-
elif padding == "VALID":
40-
auto_pad = "VALID"
41-
return auto_pad
42-
elif isinstance(padding, int):
43-
pads = [padding] * spatial * 2
44-
return pads
45-
elif isinstance(padding, tuple):
46-
return list(padding) * 2
47-
48-
def convert_w(w, data_format, spatial, w_name):
49-
w = tlx.convert_to_numpy(w)
50-
if tlx.BACKEND == 'tensorflow':
51-
if spatial == 2:
52-
w = np.transpose(w, axes=[3, 2, 0, 1])
53-
elif spatial == 1:
54-
w = np.transpose(w, axes=[2, 1, 0])
55-
elif spatial == 3:
56-
w = np.transpose(w, axes=[4, 3, 0, 1, 2])
57-
return numpy_helper.from_array(w, name=w_name)
58-
elif tlx.BACKEND == 'mindspore':
59-
if spatial == 2 and data_format == 'channels_last':
60-
w = np.transpose(w, axes=[3, 0, 1, 2])
61-
return numpy_helper.from_array(w, name=w_name)
62-
return numpy_helper.from_array(w, name=w_name)
63-
64-
def convert_b(b, b_name):
65-
b = tlx.convert_to_numpy(b)
66-
return numpy_helper.from_array(b, name=b_name)
67-
10+
from tlx2onnx.common import make_shape_channels_first, get_channels_first_permutation,get_channels_last_permutation
11+
from tlx2onnx.common import convert_padding, convert_w, tlx_act_2_onnx, convert_b
6812

6913
@OpMapper(["GroupConv2d"])
7014
class Conv():

0 commit comments

Comments
 (0)