@@ -79,4 +79,63 @@ def version_1(cls, node, **kwargs):
7979 outputs = node ['out_nodes_name' ]
8080 )
8181 onnx_node .append (bn_node )
82+ return onnx_node , onnx_value , onnx_init
83+
84+
85+ @OpMapper (['LayerNorm' ])
86+ class LayerNorm ():
87+ # supports v17
88+
89+ @classmethod
90+ def version_17 (cls , node , ** kwargs ):
91+ onnx_node = []
92+ onnx_value = []
93+ onnx_init = []
94+
95+ # input , output, data_format
96+ x = node ['in_nodes_name' ][0 ]
97+ x_shape = node ['in_tensors' ][0 ]
98+
99+ out_shape = node ['out_tensors' ][0 ]
100+ out_v = helper .make_tensor_value_info (node ['out_nodes_name' ][0 ], NP_TYPE_TO_TENSOR_TYPE [node ['dtype' ]],
101+ shape = node ['out_tensors' ][0 ])
102+ onnx_value .append (out_v )
103+
104+ spatial = 2
105+ # get parameters
106+ beta_name = node ['node' ].layer .name + '/beta'
107+ beta_weight = numpy_helper .from_array (arr = to_numpy (node ['node' ].layer .beta ), name = beta_name )
108+ onnx_init .append (beta_weight )
109+
110+ gamma_name = node ['node' ].layer .name + '/gamma'
111+ gamma_weight = numpy_helper .from_array (arr = to_numpy (node ['node' ].layer .gamma ), name = gamma_name )
112+ onnx_init .append (gamma_weight )
113+
114+ epsilon = node ['node' ].layer .epsilon
115+
116+ # if data_format == 'channels_last':
117+ # channels last conver weights and input
118+ x_shape = make_shape_channels_first (x_shape )
119+ out_temp_shape = make_shape_channels_first (out_shape )
120+ # make channels transpose
121+ t_x = helper .make_tensor_value_info (node ['in_nodes_name' ][0 ] + 't' ,
122+ NP_TYPE_TO_TENSOR_TYPE [node ['dtype' ]], shape = x_shape )
123+ onnx_value .append (t_x )
124+ tx_node , x = make_node ('Transpose' , inputs = [x ], outputs = [node ['in_nodes_name' ][0 ] + 't' ],
125+ perm = get_channels_first_permutation (spatial ))
126+ onnx_node .append (tx_node )
127+ # make batch normalization
128+ out_temp = helper .make_tensor_value_info (node ['out_nodes_name' ][0 ] + 'bn' ,
129+ NP_TYPE_TO_TENSOR_TYPE [node ['dtype' ]], shape = out_temp_shape )
130+ onnx_value .append (out_temp )
131+ ln_node , out = make_node ('LayerNormalization' ,
132+ inputs = [node ['in_nodes_name' ][0 ] + 't' , beta_name , gamma_name ],
133+ outputs = [node ['out_nodes_name' ][0 ] + 'bn' ], epsilon = epsilon
134+ )
135+ onnx_node .append (ln_node )
136+ # make channels transpose
137+ t_out = helper .make_tensor_value_info (node ['out_nodes_name' ][0 ], NP_TYPE_TO_TENSOR_TYPE [node ['dtype' ]], shape = out_shape )
138+ onnx_value .append (t_out )
139+ tout_node , _ = make_node ('Transpose' , inputs = [out ], outputs = node ['out_nodes_name' ], perm = get_channels_last_permutation (spatial ))
140+ onnx_node .append (tout_node )
82141 return onnx_node , onnx_value , onnx_init
0 commit comments