Skip to content

Commit a8ee95a

Browse files
authored
Update Normalization (#17)
1 parent 472d4fd commit a8ee95a

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

tests/test_batchnorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self):
1616
super(MLP, self).__init__()
1717
# weights init
1818
self.conv1 = Conv2d(out_channels=16, kernel_size=3, stride=1, padding=(2, 2), in_channels=3, data_format='channels_last', act = tlx.nn.ReLU)
19-
self.bn = BatchNorm2d(data_format='channels_last')
19+
self.bn = BatchNorm2d(data_format='channels_last', act=tlx.nn.ReLU)
2020

2121

2222
def forward(self, x):

tests/test_layernorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class NET(Module):
1515
def __init__(self):
1616
super(NET, self).__init__()
17-
self.layernorm = LayerNorm([50, 50, 32])
17+
self.layernorm = LayerNorm([50, 50, 32], act=tlx.nn.ReLU)
1818

1919
def forward(self, x):
2020
x = self.layernorm(x)

tlx2onnx/op_mapper/nn/normalization.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from tlx2onnx.op_mapper.datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
66
from tlx2onnx.op_mapper.op_mapper import OpMapper
77
from tlx2onnx.common import make_node, to_numpy, make_shape_channels_first, make_shape_channels_last, \
8-
get_channels_first_permutation, get_channels_last_permutation
8+
get_channels_first_permutation, get_channels_last_permutation, tlx_act_2_onnx
99

1010

1111
@OpMapper(['BatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'])
@@ -66,6 +66,12 @@ def version_1(cls, node, **kwargs):
6666
outputs=[node['out_nodes_name'][0] + 'bn']
6767
)
6868
onnx_node.append(bn_node)
69+
70+
if node['node'].layer.act is not None:
71+
act_op = node['node'].layer.act.__class__.__name__
72+
act_node, out = tlx_act_2_onnx[act_op]([out], [node['out_nodes_name'][0] + 'act'], node['node'].layer.act)
73+
onnx_node.append(act_node)
74+
6975
# make channels transpose
7076
t_out = helper.make_tensor_value_info(node['out_nodes_name'][0], NP_TYPE_TO_TENSOR_TYPE[node['dtype']], shape=out_shape)
7177
onnx_value.append(t_out)
@@ -74,11 +80,22 @@ def version_1(cls, node, **kwargs):
7480

7581

7682
elif data_format == 'channels_first':
77-
bn_node, out = make_node('BatchNormalization',
78-
inputs=[x, beta_name, gamma_name, mean_name, var_name],
79-
outputs=node['out_nodes_name']
80-
)
81-
onnx_node.append(bn_node)
83+
if node['node'].layer.act is None:
84+
bn_node, out = make_node('BatchNormalization',
85+
inputs=[x, beta_name, gamma_name, mean_name, var_name],
86+
outputs=node['out_nodes_name']
87+
)
88+
onnx_node.append(bn_node)
89+
else:
90+
bn_node, out = make_node('BatchNormalization',
91+
inputs=[x, beta_name, gamma_name, mean_name, var_name],
92+
outputs=[node['out_nodes_name'][0] + 'bn']
93+
)
94+
onnx_node.append(bn_node)
95+
act_op = node['node'].layer.act.__class__.__name__
96+
act_node, out = tlx_act_2_onnx[act_op]([out], node['out_nodes_name'], node['node'].layer.act)
97+
onnx_node.append(act_node)
98+
8299
return onnx_node, onnx_value, onnx_init
83100

84101

@@ -133,6 +150,12 @@ def version_17(cls, node, **kwargs):
133150
outputs=[node['out_nodes_name'][0] + 'bn'], epsilon=epsilon
134151
)
135152
onnx_node.append(ln_node)
153+
154+
if node['node'].layer.act is not None:
155+
act_op = node['node'].layer.act.__class__.__name__
156+
act_node, out = tlx_act_2_onnx[act_op]([out], [node['out_nodes_name'][0] + 'act'], node['node'].layer.act)
157+
onnx_node.append(act_node)
158+
136159
# make channels transpose
137160
t_out = helper.make_tensor_value_info(node['out_nodes_name'][0], NP_TYPE_TO_TENSOR_TYPE[node['dtype']], shape=out_shape)
138161
onnx_value.append(t_out)

0 commit comments

Comments
 (0)