
Commit 8fbba88

add gru converter
1 parent d35bb36 commit 8fbba88

2 files changed (+248, -9 lines)

tests/test_rnn.py

Lines changed: 32 additions & 7 deletions
@@ -1,22 +1,22 @@
 #! /usr/bin/python
 # -*- coding: utf-8 -*-
 import os
-os.environ["TL_BACKEND"] = 'tensorflow'
+# os.environ["TL_BACKEND"] = 'tensorflow'
 # os.environ["TL_BACKEND"] = 'paddle'
-# os.environ["TL_BACKEND"] = 'torch'
+os.environ["TL_BACKEND"] = 'torch'
 # os.environ["TL_BACKEND"] = 'mindspore'
 import tensorlayerx as tlx
 from tensorlayerx.nn import Module
-from tensorlayerx.nn import RNN, LSTM, Linear
+from tensorlayerx.nn import RNN, LSTM, GRU, Linear
 from tlx2onnx.main import export
 import onnxruntime as rt
 import numpy as np
+tlx.set_seed(42)
 
 class Rnn(Module):
 
     def __init__(self):
         super(Rnn, self).__init__()
-        self.l1 = Linear(in_features=5, out_features=7)
         self.rnn = RNN(input_size=5, hidden_size=5, act='relu', bidirectional=False, num_layers=4)
     def forward(self, x):
         x, _ = self.rnn(x)
@@ -34,14 +34,14 @@ def forward(self, x):
 
 input_data = np.array(input, dtype=np.float32)
 result = sess.run([output_name], {input_name: input_data})
-print("rnn onnx output", result)
+print("RNN onnx output", result)
+print("==========================================================================================================")
 
 
 class Lstm(Module):
 
     def __init__(self):
         super(Lstm, self).__init__()
-        self.l1 = Linear(in_features=5, out_features=7)
         self.rnn = LSTM(input_size=5, hidden_size=5, bidirectional=True, num_layers=4)
     def forward(self, x):
         x, _ = self.rnn(x)
@@ -59,4 +59,29 @@ def forward(self, x):
 
 input_data = np.array(input, dtype=np.float32)
 result = sess.run([output_name], {input_name: input_data})
-print("LSTM onnx output", result)
+print("LSTM onnx output", result)
+print("==========================================================================================================")
+
+class Gru(Module):
+
+    def __init__(self):
+        super(Gru, self).__init__()
+        self.rnn = GRU(input_size=5, hidden_size=5, bidirectional=True, num_layers=4)
+    def forward(self, x):
+        x, _ = self.rnn(x)
+        return x
+
+model = Gru()
+input = tlx.nn.Input(shape=[1, 5, 5])
+model.set_eval()
+output = model(input)
+print("GRU tlx output", output)
+onnx_model = export(model, input_spec=input, path='gru.onnx')
+
+sess = rt.InferenceSession('gru.onnx')
+input_name = sess.get_inputs()[0].name
+output_name = sess.get_outputs()[0].name
+
+input_data = np.array(input, dtype=np.float32)
+result = sess.run([output_name], {input_name: input_data})
+print("GRU onnx output", result)

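Note: the test above only prints the TensorLayerX and onnxruntime outputs next to each other. A numerical comparison along the following lines would make the check explicit. This is a minimal sketch and not part of the commit; it assumes the gru.onnx file produced above and arbitrarily chosen tolerances.

    # Sketch (not part of the commit): compare the TLX output with the onnxruntime output.
    import numpy as np
    import onnxruntime as rt
    import tensorlayerx as tlx

    def check_outputs(tlx_output, onnx_path, input_data, rtol=1e-4, atol=1e-5):
        sess = rt.InferenceSession(onnx_path)
        input_name = sess.get_inputs()[0].name
        output_name = sess.get_outputs()[0].name
        onnx_output = sess.run([output_name], {input_name: input_data})[0]
        # convert the TLX tensor to numpy before the element-wise comparison
        np.testing.assert_allclose(tlx.convert_to_numpy(tlx_output), onnx_output, rtol=rtol, atol=atol)

    # usage with the objects defined in the test:
    # check_outputs(output, 'gru.onnx', np.array(input, dtype=np.float32))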
tlx2onnx/op_mapper/nn/rnn.py

Lines changed: 216 additions & 2 deletions
@@ -96,7 +96,7 @@ def concat_states(cls, num, states, bidrectional):
 
 
     @classmethod
-    def version_1(cls, node, **kwargs):
+    def version_7(cls, node, **kwargs):
         onnx_node = []
         onnx_value = []
         onnx_init = []
@@ -307,7 +307,7 @@ def concat_states(cls, num, states, bidrectional):
         return states_hi, states_ci
 
     @classmethod
-    def version_1(cls, node, **kwargs):
+    def version_7(cls, node, **kwargs):
         onnx_node = []
         onnx_value = []
         onnx_init = []
@@ -356,6 +356,220 @@ def version_1(cls, node, **kwargs):
         attr_dict["activations"] = ['Sigmoid', 'Tanh', 'Tanh'] * bidirect
         attr_dict["input_forget"] = 0
 
+        def name(num, name):
+            return layer_name + '_' + name + '_' + str(num)
+
+        input = x_name
+        for i in range(num_layers):
+            w_i, r_i, b_i = cls.concat_params(i, weight_ih, weight_hh, bias_ih, bias_hh, bidirectional, hidden_size)
+            attr_dict["inputs"] = [input]
+            w_i_name = name(i, "w")
+            attr_dict["inputs"].append(w_i_name)
+            w_i_init = numpy_helper.from_array(w_i, w_i_name)
+            onnx_init.append(w_i_init)
+            r_i_name = name(i, 'r')
+            attr_dict["inputs"].append(r_i_name)
+            r_i_init = numpy_helper.from_array(r_i, r_i_name)
+            onnx_init.append(r_i_init)
+            if b_i is not None:
+                b_i_name = name(i, 'b')
+                attr_dict["inputs"].append(b_i_name)
+                b_i_init = numpy_helper.from_array(b_i, b_i_name)
+                onnx_init.append(b_i_init)
+            else:
+                attr_dict["inputs"].append("")
+            # add sequence_lens into inputs
+            if states is not None:
+                state_hi_name = name(i, 'h')
+                attr_dict["inputs"].append("")
+                attr_dict["inputs"].append(state_hi_name)
+
+                state_ci_name = name(i, 'c')
+                attr_dict["inputs"].append(state_ci_name)
+                state_hi, state_ci = cls.concat_states(i, states, bidirectional)
+                state_hi_init = numpy_helper.from_array(state_hi, state_hi_name)
+                onnx_init.append(state_hi_init)
+                state_ci_init = numpy_helper.from_array(state_ci, state_ci_name)
+                onnx_init.append(state_ci_init)
+
+            attr_dict["outputs"] = [name(i, 'y')]
+            rnn_node, y_out = make_node(op_type, **attr_dict)
+            onnx_node.append(rnn_node)
+            transpose_node, y_out_T = make_node("Transpose", inputs=[y_out], outputs=[y_out + "_T"], perm=[0,2,1,3])
+            onnx_node.append(transpose_node)
+            shape = np.array([0, 0, -1], dtype=np.int64)
+            shape_name = name(i, 'shape')
+            shape_value = numpy_helper.from_array(shape, shape_name)
+            onnx_init.append(shape_value)
+            if i + 1 < num_layers:
+                reshape_output = [y_out + "_R"]
+                reshape_node, y_out_R = make_node("Reshape", inputs=[y_out_T, shape_name], outputs=reshape_output)
+                input = y_out_R
+            else:
+                reshape_node, y_out_R = make_node("Reshape", inputs=[y_out_T, shape_name], outputs=[out_name])
+            onnx_node.append(reshape_node)
+
+        return onnx_node, onnx_value, onnx_init
+
+@OpMapper(["GRU"])
+class GRU():
+    # support v1-v11
+
+    @classmethod
+    def concat_params(cls, num, weight_ih, weight_hh, bias_ih, bias_hh, bidrectional, hidden_size):
+
+        def reform_weights(weights, hidden_size):
+            reform_permutation = [(1, 2), (0, 1), (2, 3)]
+            slices = []
+            for x, y in reform_permutation:
+                start = x * hidden_size
+                end = y * hidden_size
+                slices.append(weights[start:end])
+            return np.concatenate(slices, axis=0)
+        ih_i, hh_i, b_i = None, None, None
+        if bidrectional:
+            id = num * 2
+            # get i-th rnn layer's weights - ih_i input to hidden
+            ih_i_forward = weight_ih[id]
+            ih_i_reverse = weight_ih[id + 1]
+            ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
+            ih_i_reverse = tlx.convert_to_numpy(ih_i_reverse)
+            ih_i_forward = reform_weights(ih_i_forward, hidden_size)
+            ih_i_reverse = reform_weights(ih_i_reverse, hidden_size)
+            ih_i_forward = ih_i_forward[np.newaxis, :, :]
+            ih_i_reverse = ih_i_reverse[np.newaxis, :, :]
+            ih_i = np.concatenate((ih_i_forward, ih_i_reverse), axis=0)
+
+            # get i-th rnn layer's weights - hh_i hidden to hidden
+            hh_i_forward = weight_hh[id]
+            hh_i_reverse = weight_hh[id + 1]
+            hh_i_forward = tlx.convert_to_numpy(hh_i_forward)
+            hh_i_reverse = tlx.convert_to_numpy(hh_i_reverse)
+            hh_i_forward = reform_weights(hh_i_forward, hidden_size)
+            hh_i_reverse = reform_weights(hh_i_reverse, hidden_size)
+            hh_i_forward = hh_i_forward[np.newaxis, :, :]
+            hh_i_reverse = hh_i_reverse[np.newaxis, :, :]
+            hh_i = np.concatenate((hh_i_forward, hh_i_reverse), axis=0)
+
+            if bias_ih is not None:
+                # get i-th rnn layer's bias - ih_b input to hidden
+                b_ih_forward = bias_ih[id]
+                b_ih_reverse = bias_ih[id + 1]
+                b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
+                b_ih_reverse = tlx.convert_to_numpy(b_ih_reverse)
+                b_ih_forward = reform_weights(b_ih_forward, hidden_size)
+                b_ih_reverse = reform_weights(b_ih_reverse, hidden_size)
+                b_ih_forward = b_ih_forward[np.newaxis, :]
+                b_ih_reverse = b_ih_reverse[np.newaxis, :]
+                # get i-th rnn layer's bias - hh_b hidden to hidden
+                b_hh_forward = bias_hh[id]
+                b_hh_reverse = bias_hh[id + 1]
+                b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
+                b_hh_reverse = tlx.convert_to_numpy(b_hh_reverse)
+                b_hh_forward = reform_weights(b_hh_forward, hidden_size)
+                b_hh_reverse = reform_weights(b_hh_reverse, hidden_size)
+                b_hh_forward = b_hh_forward[np.newaxis, :]
+                b_hh_reverse = b_hh_reverse[np.newaxis, :]
+
+                # concat bias
+                b_forward = np.concatenate((b_ih_forward, b_hh_forward), axis=-1)
+                b_reverse = np.concatenate((b_ih_reverse, b_hh_reverse), axis=-1)
+                b_i = np.concatenate((b_forward, b_reverse), axis=0)
+        else:
+            # get i-th rnn layer's weights - ih_i input to hidden
+            ih_i_forward = weight_ih[num]
+            ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
+            ih_i_forward = reform_weights(ih_i_forward, hidden_size)
+            ih_i = ih_i_forward[np.newaxis, :, :]
+
+            # get i-th rnn layer's weights - hh_i hidden to hidden
+            hh_i_forward = weight_hh[num]
+            hh_i_forward = tlx.convert_to_numpy(hh_i_forward)
+            hh_i_forward = reform_weights(hh_i_forward, hidden_size)
+            hh_i = hh_i_forward[np.newaxis, :, :]
+
+            if bias_ih is not None:
+                # get i-th rnn layer's bias - ih_b input to hidden
+                b_ih_forward = bias_ih[num]
+                b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
+                b_ih_forward = reform_weights(b_ih_forward, hidden_size)
+                b_ih_forward = b_ih_forward[np.newaxis, :]
+                # get i-th rnn layer's bias - hh_b hidden to hidden
+                b_hh_forward = bias_hh[num]
+                b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
+                b_hh_forward = reform_weights(b_hh_forward, hidden_size)
+                b_hh_forward = b_hh_forward[np.newaxis, :]
+
+                # concat bias
+                b_i = np.concatenate((b_ih_forward, b_hh_forward), axis=-1)
+
+        return ih_i, hh_i, b_i
+
+    @classmethod
+    def concat_states(cls, num, states, bidrectional):
+        states_h = tlx.convert_to_numpy(states[0])
+        states_c = tlx.convert_to_numpy(states[1])
+        if bidrectional:
+            id = num * 2
+            states_hi = states_h[id: id+2, :, :]
+            states_ci = states_c[id: id+2, :, :]
+        else:
+            states_hi = states_h[num, :, :]
+            states_hi = states_hi[np.newaxis, :, :]
+            states_ci = states_c[num, :, :]
+            states_ci = states_ci[np.newaxis, :, :]
+        return states_hi, states_ci
+
+    @classmethod
+    def version_7(cls, node, **kwargs):
+        onnx_node = []
+        onnx_value = []
+        onnx_init = []
+
+        op_type = "GRU"
+        attr_dict = OrderedDict()
+        # get in_node_name out_node_name
+        x_name = node['in_nodes_name'][0]
+        out_name = node['out_nodes_name'][0]
+        x_shape = node['in_tensors'][0]
+        out_shape = node['out_tensors'][0]
+
+        # get data_type
+        data_type = node['dtype']
+        tensor_type = NP_TYPE_TO_TENSOR_TYPE[data_type]
+
+        # get cur_node_layer node_index
+        layer = node['node'].layer
+        layer_name = layer.__class__.__name__
+
+
+        # get layer attr
+        input_size = layer.input_size
+        hidden_size = layer.hidden_size
+        num_layers = layer.num_layers
+        bias = layer.bias
+        batch_first = layer.batch_first
+        # dropout = layer.dropout  # dropout is not needed for inference
+        bidirectional = layer.bidirectional
+        states = layer.states
+        # new_states = layer.new_states
+        bidirect = 2 if bidirectional else 1
+
+        # get layer weights
+        weight_ih = layer.weight_ih
+        weight_hh = layer.weight_hh
+        bias_ih = None
+        bias_hh = None
+        if bias:
+            bias_ih = layer.bias_ih
+            bias_hh = layer.bias_hh
+
+        attr_dict["direction"] = "bidirectional" if bidirectional else "forward"
+        attr_dict["layout"] = 1 if batch_first else 0
+        attr_dict["hidden_size"] = hidden_size
+        attr_dict["activations"] = ['Sigmoid', 'Tanh'] * bidirect
+        attr_dict["linear_before_reset"] = 1
+
         def name(num, name):
             return layer_name + '_' + name + '_' + str(num)

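Note on the weight layout handled by reform_weights above: assuming the backend stores GRU gate parameters PyTorch-style as stacked blocks of size hidden_size in the order (reset, update, new), while the ONNX GRU operator expects (update, reset, hidden), the slice permutation [(1, 2), (0, 1), (2, 3)] swaps the first two blocks and leaves the third in place. Setting linear_before_reset = 1 likewise matches the PyTorch-style candidate computation, where the hidden-to-hidden projection is applied before the reset gate is multiplied in. A toy illustration of the reordering (hypothetical values, not from the commit):

    import numpy as np

    def reform_weights(weights, hidden_size):
        # same slice permutation as the converter: take block 1, then block 0, then block 2
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
        slices = [weights[x * hidden_size: y * hidden_size] for x, y in reform_permutation]
        return np.concatenate(slices, axis=0)

    hidden_size = 2
    w = np.arange(6).reshape(6, 1)                 # rows grouped by gate block: r, r, z, z, n, n
    print(reform_weights(w, hidden_size).ravel())  # [2 3 0 1 4 5] -> z block first, then r, then n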
0 commit comments
