@OpMapper(["RNN"])
class RNN():
    # supports v1-v11
    @classmethod
    def concat_params(cls, num, weight_ih, weight_hh, bias_ih, bias_hh, bidrectional):
        ih_i, hh_i, b_i = None, None, None
        if bidrectional:
            id = num * 2
            # get i-th rnn layer's weights - ih_i input to hidden
            ih_i_forward = weight_ih[id]
            ih_i_reverse = weight_ih[id + 1]
            ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
            ih_i_reverse = tlx.convert_to_numpy(ih_i_reverse)
            ih_i_forward = ih_i_forward[np.newaxis, :, :]
            # ... (rest of the bidirectional weight handling elided) ...
            if bias_ih is not None:
                # get i-th rnn layer's bias - ih_b input to hidden
                b_ih_forward = bias_ih[id]
                b_ih_reverse = bias_ih[id + 1]
                b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
                b_ih_reverse = tlx.convert_to_numpy(b_ih_reverse)
                b_ih_forward = b_ih_forward[np.newaxis, :]
                b_ih_reverse = b_ih_reverse[np.newaxis, :]
                # get i-th rnn layer's bias - hh_b hidden to hidden
                b_hh_forward = bias_hh[id]
                b_hh_reverse = bias_hh[id + 1]
                b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
                b_hh_reverse = tlx.convert_to_numpy(b_hh_reverse)
                b_hh_forward = b_hh_forward[np.newaxis, :]
                # ... (remainder of concat_params elided) ...

    @classmethod
    def version_1(cls, node, **kwargs):
        # ... (beginning of version_1 elided) ...
        # dropout = layer.dropout  # dropout is not needed for inference
        bidirectional = layer.bidirectional
        act = layer.mode[4:]
        states = layer.states
        # new_states = layer.new_states
        bidirect = 2 if bidirectional else 1
        # ... (rest of version_1 elided) ...


@OpMapper(["LSTM"])
class LSTM():
    # supports v1-v11

    @classmethod
    def concat_params(cls, num, weight_ih, weight_hh, bias_ih, bias_hh, bidrectional, hidden_size):

        def reform_weights(weights, hidden_size):
            reform_permutation = [(0, 1), (3, 4), (1, 3)]
            slices = []
            for x, y in reform_permutation:
                start = x * hidden_size
                end = y * hidden_size
                slices.append(weights[start:end])
            return np.concatenate(slices, axis=0)
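        # TLX (PyTorch-style) LSTM parameters store the four gate blocks in
        # [input, forget, cell, output] order, while the ONNX LSTM operator
        # expects [input, output, forget, cell]; reform_weights reorders the
        # blocks accordingly.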
        ih_i, hh_i, b_i = None, None, None
        if bidrectional:
            id = num * 2
            # get i-th rnn layer's weights - ih_i input to hidden
            ih_i_forward = weight_ih[id]
            ih_i_reverse = weight_ih[id + 1]
            ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
            ih_i_reverse = tlx.convert_to_numpy(ih_i_reverse)
            ih_i_forward = reform_weights(ih_i_forward, hidden_size)
            ih_i_reverse = reform_weights(ih_i_reverse, hidden_size)
            ih_i_forward = ih_i_forward[np.newaxis, :, :]
            ih_i_reverse = ih_i_reverse[np.newaxis, :, :]
            ih_i = np.concatenate((ih_i_forward, ih_i_reverse), axis=0)

            # get i-th rnn layer's weights - hh_i hidden to hidden
            hh_i_forward = weight_hh[id]
            hh_i_reverse = weight_hh[id + 1]
            hh_i_forward = tlx.convert_to_numpy(hh_i_forward)
            hh_i_reverse = tlx.convert_to_numpy(hh_i_reverse)
            hh_i_forward = reform_weights(hh_i_forward, hidden_size)
            hh_i_reverse = reform_weights(hh_i_reverse, hidden_size)
            hh_i_forward = hh_i_forward[np.newaxis, :, :]
            hh_i_reverse = hh_i_reverse[np.newaxis, :, :]
            hh_i = np.concatenate((hh_i_forward, hh_i_reverse), axis=0)

            if bias_ih is not None:
                # get i-th rnn layer's bias - ih_b input to hidden
                b_ih_forward = bias_ih[id]
                b_ih_reverse = bias_ih[id + 1]
                b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
                b_ih_reverse = tlx.convert_to_numpy(b_ih_reverse)
                b_ih_forward = reform_weights(b_ih_forward, hidden_size)
                b_ih_reverse = reform_weights(b_ih_reverse, hidden_size)
                b_ih_forward = b_ih_forward[np.newaxis, :]
                b_ih_reverse = b_ih_reverse[np.newaxis, :]
                # get i-th rnn layer's bias - hh_b hidden to hidden
                b_hh_forward = bias_hh[id]
                b_hh_reverse = bias_hh[id + 1]
                b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
                b_hh_reverse = tlx.convert_to_numpy(b_hh_reverse)
                b_hh_forward = reform_weights(b_hh_forward, hidden_size)
                b_hh_reverse = reform_weights(b_hh_reverse, hidden_size)
                b_hh_forward = b_hh_forward[np.newaxis, :]
                b_hh_reverse = b_hh_reverse[np.newaxis, :]

                # concat bias
                b_forward = np.concatenate((b_ih_forward, b_hh_forward), axis=-1)
                b_reverse = np.concatenate((b_ih_reverse, b_hh_reverse), axis=-1)
                b_i = np.concatenate((b_forward, b_reverse), axis=0)
        else:
            # get i-th rnn layer's weights - ih_i input to hidden
            ih_i_forward = weight_ih[num]
            ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
            ih_i_forward = reform_weights(ih_i_forward, hidden_size)
            ih_i = ih_i_forward[np.newaxis, :, :]

            # get i-th rnn layer's weights - hh_i hidden to hidden
            hh_i_forward = weight_hh[num]
            hh_i_forward = tlx.convert_to_numpy(hh_i_forward)
            hh_i_forward = reform_weights(hh_i_forward, hidden_size)
            hh_i = hh_i_forward[np.newaxis, :, :]

            if bias_ih is not None:
                # get i-th rnn layer's bias - ih_b input to hidden
                b_ih_forward = bias_ih[num]
                b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
                b_ih_forward = reform_weights(b_ih_forward, hidden_size)
                b_ih_forward = b_ih_forward[np.newaxis, :]
                # get i-th rnn layer's bias - hh_b hidden to hidden
                b_hh_forward = bias_hh[num]
                b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
                b_hh_forward = reform_weights(b_hh_forward, hidden_size)
                b_hh_forward = b_hh_forward[np.newaxis, :]

                # concat bias
                b_i = np.concatenate((b_ih_forward, b_hh_forward), axis=-1)

        return ih_i, hh_i, b_i

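    # The initial hidden and cell states fed to the ONNX LSTM node must have shape
    # [num_directions, batch_size, hidden_size]; concat_states slices the i-th
    # layer's portion out of the layer's stacked states.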
    @classmethod
    def concat_states(cls, num, states, bidrectional):
        states_h = tlx.convert_to_numpy(states[0])
        states_c = tlx.convert_to_numpy(states[1])
        if bidrectional:
            id = num * 2
            states_hi = states_h[id: id + 2, :, :]
            states_ci = states_c[id: id + 2, :, :]
        else:
            states_hi = states_h[num, :, :]
            states_hi = states_hi[np.newaxis, :, :]
            states_ci = states_c[num, :, :]
            states_ci = states_ci[np.newaxis, :, :]
        return states_hi, states_ci

    @classmethod
    def version_1(cls, node, **kwargs):
        onnx_node = []
        onnx_value = []
        onnx_init = []

        op_type = "LSTM"
        attr_dict = OrderedDict()
        # get in_node_name and out_node_name
        x_name = node['in_nodes_name'][0]
        out_name = node['out_nodes_name'][0]
        x_shape = node['in_tensors'][0]
        out_shape = node['out_tensors'][0]

        # get data_type
        data_type = node['dtype']
        tensor_type = NP_TYPE_TO_TENSOR_TYPE[data_type]

        # get cur_node_layer node_index
        layer = node['node'].layer
        layer_name = layer.__class__.__name__

        # get layer attr
        input_size = layer.input_size
        hidden_size = layer.hidden_size
        num_layers = layer.num_layers
        bias = layer.bias
        batch_first = layer.batch_first
        # dropout = layer.dropout  # dropout is not needed for inference
        bidirectional = layer.bidirectional
        states = layer.states
        # new_states = layer.new_states
        bidirect = 2 if bidirectional else 1

        # get layer weights
        weight_ih = layer.weight_ih
        weight_hh = layer.weight_hh
        bias_ih = None
        bias_hh = None
        if bias:
            bias_ih = layer.bias_ih
            bias_hh = layer.bias_hh

        attr_dict["direction"] = "bidirectional" if bidirectional else "forward"
        attr_dict["layout"] = 1 if batch_first else 0
        attr_dict["hidden_size"] = hidden_size
        attr_dict["activations"] = ['Sigmoid', 'Tanh', 'Tanh'] * bidirect
        attr_dict["input_forget"] = 0
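        # "activations" lists the [f, g, h] activation functions for each direction
        # (gate, cell and hidden activations of a standard LSTM), and
        # "input_forget" = 0 leaves the input and forget gates uncoupled.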

        def name(num, name):
            return layer_name + '_' + name + '_' + str(num)

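        # Per the ONNX spec the LSTM inputs are, in order: X, W, R, B, sequence_lens,
        # initial_h, initial_c; an empty string in the inputs list stands for an
        # omitted optional input.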
        input = x_name
        for i in range(num_layers):
            w_i, r_i, b_i = cls.concat_params(i, weight_ih, weight_hh, bias_ih, bias_hh, bidirectional, hidden_size)
            attr_dict["inputs"] = [input]
            w_i_name = name(i, "w")
            attr_dict["inputs"].append(w_i_name)
            w_i_init = numpy_helper.from_array(w_i, w_i_name)
            onnx_init.append(w_i_init)
            r_i_name = name(i, 'r')
            attr_dict["inputs"].append(r_i_name)
            r_i_init = numpy_helper.from_array(r_i, r_i_name)
            onnx_init.append(r_i_init)
            if b_i is not None:
                b_i_name = name(i, 'b')
                attr_dict["inputs"].append(b_i_name)
                b_i_init = numpy_helper.from_array(b_i, b_i_name)
                onnx_init.append(b_i_init)
            else:
                attr_dict["inputs"].append("")
            # add sequence_lens into inputs
            if states is not None:
                state_hi_name = name(i, 'h')
                attr_dict["inputs"].append("")
                attr_dict["inputs"].append(state_hi_name)

                state_ci_name = name(i, 'c')
                attr_dict["inputs"].append(state_ci_name)
                state_hi, state_ci = cls.concat_states(i, states, bidirectional)
                state_hi_init = numpy_helper.from_array(state_hi, state_hi_name)
                onnx_init.append(state_hi_init)
                state_ci_init = numpy_helper.from_array(state_ci, state_ci_name)
                onnx_init.append(state_ci_init)

            attr_dict["outputs"] = [name(i, 'y')]
            rnn_node, y_out = make_node(op_type, **attr_dict)
            onnx_node.append(rnn_node)
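            # With layout = 0 the LSTM output Y has shape
            # [seq_length, num_directions, batch_size, hidden_size]; the Transpose and
            # Reshape below merge the direction axis into the feature axis so the next
            # layer (or the graph output) sees [seq_length, batch_size, num_directions * hidden_size].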
            transpose_node, y_out_T = make_node("Transpose", inputs=[y_out], outputs=[y_out + "_T"], perm=[0, 2, 1, 3])
            onnx_node.append(transpose_node)
            shape = np.array([0, 0, -1], dtype=np.int64)
            shape_name = name(i, 'shape')
            shape_value = numpy_helper.from_array(shape, shape_name)
            onnx_init.append(shape_value)
            if i + 1 < num_layers:
                reshape_output = [y_out + "_R"]
                reshape_node, y_out_R = make_node("Reshape", inputs=[y_out_T, shape_name], outputs=reshape_output)
                input = y_out_R
            else:
                reshape_node, y_out_R = make_node("Reshape", inputs=[y_out_T, shape_name], outputs=[out_name])
            onnx_node.append(reshape_node)

        return onnx_node, onnx_value, onnx_init
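

# A minimal, self-contained sanity check of the gate reordering performed by
# reform_weights above. This is an illustrative sketch only: the toy hidden size
# and values are made up, and it does not exercise the ONNX export path.
if __name__ == "__main__":
    import numpy as np

    hidden_size = 2
    # Four gate blocks in TLX/PyTorch order: input, forget, cell, output,
    # filled with 1, 2, 3, 4 respectively so the reordering is easy to see.
    ifgo = np.concatenate([np.full((hidden_size, 3), v) for v in (1.0, 2.0, 3.0, 4.0)], axis=0)
    slices = [ifgo[x * hidden_size:y * hidden_size] for x, y in [(0, 1), (3, 4), (1, 3)]]
    iofc = np.concatenate(slices, axis=0)
    # Expected ONNX order: input, output, forget, cell -> blocks 1, 4, 2, 3
    assert np.array_equal(iofc[:, 0], np.repeat([1.0, 4.0, 2.0, 3.0], hidden_size))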