55from tlx2onnx .op_mapper .datatype_mapping import NP_TYPE_TO_TENSOR_TYPE
66from tlx2onnx .op_mapper .op_mapper import OpMapper
77from tlx2onnx .common import make_node , to_numpy , make_shape_channels_first , make_shape_channels_last , \
8- get_channels_first_permutation , get_channels_last_permutation
8+ get_channels_first_permutation , get_channels_last_permutation , tlx_act_2_onnx
99
1010
1111@OpMapper (['BatchNorm' , 'BatchNorm1d' , 'BatchNorm2d' , 'BatchNorm3d' ])
@@ -66,6 +66,12 @@ def version_1(cls, node, **kwargs):
6666 outputs = [node ['out_nodes_name' ][0 ] + 'bn' ]
6767 )
6868 onnx_node .append (bn_node )
69+
70+ if node ['node' ].layer .act is not None :
71+ act_op = node ['node' ].layer .act .__class__ .__name__
72+ act_node , out = tlx_act_2_onnx [act_op ]([out ], [node ['out_nodes_name' ][0 ] + 'act' ], node ['node' ].layer .act )
73+ onnx_node .append (act_node )
74+
6975 # make channels transpose
7076 t_out = helper .make_tensor_value_info (node ['out_nodes_name' ][0 ], NP_TYPE_TO_TENSOR_TYPE [node ['dtype' ]], shape = out_shape )
7177 onnx_value .append (t_out )
@@ -74,11 +80,22 @@ def version_1(cls, node, **kwargs):
7480
7581
7682 elif data_format == 'channels_first' :
77- bn_node , out = make_node ('BatchNormalization' ,
78- inputs = [x , beta_name , gamma_name , mean_name , var_name ],
79- outputs = node ['out_nodes_name' ]
80- )
81- onnx_node .append (bn_node )
83+ if node ['node' ].layer .act is None :
84+ bn_node , out = make_node ('BatchNormalization' ,
85+ inputs = [x , beta_name , gamma_name , mean_name , var_name ],
86+ outputs = node ['out_nodes_name' ]
87+ )
88+ onnx_node .append (bn_node )
89+ else :
90+ bn_node , out = make_node ('BatchNormalization' ,
91+ inputs = [x , beta_name , gamma_name , mean_name , var_name ],
92+ outputs = [node ['out_nodes_name' ][0 ] + 'bn' ]
93+ )
94+ onnx_node .append (bn_node )
95+ act_op = node ['node' ].layer .act .__class__ .__name__
96+ act_node , out = tlx_act_2_onnx [act_op ]([out ], node ['out_nodes_name' ], node ['node' ].layer .act )
97+ onnx_node .append (act_node )
98+
8299 return onnx_node , onnx_value , onnx_init
83100
84101
@@ -133,6 +150,12 @@ def version_17(cls, node, **kwargs):
133150 outputs = [node ['out_nodes_name' ][0 ] + 'bn' ], epsilon = epsilon
134151 )
135152 onnx_node .append (ln_node )
153+
154+ if node ['node' ].layer .act is not None :
155+ act_op = node ['node' ].layer .act .__class__ .__name__
156+ act_node , out = tlx_act_2_onnx [act_op ]([out ], [node ['out_nodes_name' ][0 ] + 'act' ], node ['node' ].layer .act )
157+ onnx_node .append (act_node )
158+
136159 # make channels transpose
137160 t_out = helper .make_tensor_value_info (node ['out_nodes_name' ][0 ], NP_TYPE_TO_TENSOR_TYPE [node ['dtype' ]], shape = out_shape )
138161 onnx_value .append (t_out )
0 commit comments