Commit 60dad1f

add lstm converter

1 parent c1c5cd0 · commit 60dad1f
2 files changed: +251 −15 lines changed

tests/test_rnn.py

Lines changed: 35 additions & 10 deletions
@@ -1,28 +1,28 @@
 #! /usr/bin/python
 # -*- coding: utf-8 -*-
 import os
-# os.environ["TL_BACKEND"] = 'tensorflow'
+os.environ["TL_BACKEND"] = 'tensorflow'
 # os.environ["TL_BACKEND"] = 'paddle'
-os.environ["TL_BACKEND"] = 'torch'
+# os.environ["TL_BACKEND"] = 'torch'
 # os.environ["TL_BACKEND"] = 'mindspore'
 import tensorlayerx as tlx
 from tensorlayerx.nn import Module
-from tensorlayerx.nn import RNN, LSTM
+from tensorlayerx.nn import RNN, LSTM, Linear
 from tlx2onnx.main import export
 import onnxruntime as rt
 import numpy as np
-tlx.set_seed(42)
 
-class ImdbNet(Module):
+class Rnn(Module):
 
     def __init__(self):
-        super(ImdbNet, self).__init__()
-        self.rnn = RNN(input_size=5, hidden_size=5, bidirectional=True, num_layers=4)
+        super(Rnn, self).__init__()
+        self.l1 = Linear(in_features=5, out_features=7)
+        self.rnn = RNN(input_size=5, hidden_size=5, act='relu', bidirectional=False, num_layers=4)
     def forward(self, x):
         x, _ = self.rnn(x)
         return x
-model = ImdbNet()
-input = tlx.nn.Input(shape=[2, 2, 5])
+model = Rnn()
+input = tlx.nn.Input(shape=[1, 5, 5])
 model.set_eval()
 output = model(input)
 print("RNN tlx output", output)
@@ -34,4 +34,29 @@ def forward(self, x):
 
 input_data = np.array(input, dtype=np.float32)
 result = sess.run([output_name], {input_name: input_data})
-print("rnn onnx output", result)
+print("rnn onnx output", result)
+
+
+class Lstm(Module):
+
+    def __init__(self):
+        super(Lstm, self).__init__()
+        self.l1 = Linear(in_features=5, out_features=7)
+        self.rnn = LSTM(input_size=5, hidden_size=5, bidirectional=True, num_layers=4)
+    def forward(self, x):
+        x, _ = self.rnn(x)
+        return x
+model = Lstm()
+input = tlx.nn.Input(shape=[1, 5, 5])
+model.set_eval()
+output = model(input)
+print("LSTM tlx output", output)
+onnx_model = export(model, input_spec=input, path='lstm.onnx')
+
+sess = rt.InferenceSession('lstm.onnx')
+input_name = sess.get_inputs()[0].name
+output_name = sess.get_outputs()[0].name
+
+input_data = np.array(input, dtype=np.float32)
+result = sess.run([output_name], {input_name: input_data})
+print("LSTM onnx output", result)

tlx2onnx/op_mapper/nn/rnn.py

Lines changed: 216 additions & 5 deletions
@@ -13,15 +13,14 @@
 @OpMapper(["RNN"])
 class RNN():
     # support v1-v11
-
     @classmethod
     def concat_params(cls, num, weight_ih, weight_hh, bias_ih, bias_hh, bidrectional):
         ih_i, hh_i, b_i = None, None, None
         if bidrectional:
             id = num * 2
             # get i-th rnn layer's weights - ih_i input to hidden
             ih_i_forward = weight_ih[id]
-            ih_i_reverse = weight_ih[id+1]
+            ih_i_reverse = weight_ih[id + 1]
             ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
             ih_i_reverse = tlx.convert_to_numpy(ih_i_reverse)
             ih_i_forward = ih_i_forward[np.newaxis, :, :]
@@ -40,14 +39,14 @@ def concat_params(cls, num, weight_ih, weight_hh, bias_ih, bias_hh, bidrectional
             if bias_ih is not None:
                 # get i-th rnn layer's bias - ih_b input to hidden
                 b_ih_forward = bias_ih[id]
-                b_ih_reverse = bias_ih[id+1]
+                b_ih_reverse = bias_ih[id + 1]
                 b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
                 b_ih_reverse = tlx.convert_to_numpy(b_ih_reverse)
                 b_ih_forward = b_ih_forward[np.newaxis, :]
                 b_ih_reverse = b_ih_reverse[np.newaxis, :]
                 # get i-th rnn layer's bias - hh_b hidden to hidden
                 b_hh_forward = bias_hh[id]
-                b_hh_reverse = bias_hh[id+1]
+                b_hh_reverse = bias_hh[id + 1]
                 b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
                 b_hh_reverse = tlx.convert_to_numpy(b_hh_reverse)
                 b_hh_forward = b_hh_forward[np.newaxis, :]
@@ -128,7 +127,6 @@ def version_1(cls, node, **kwargs):
         # dropout = layer.dropout # we don't need dropout inference
         bidirectional = layer.bidirectional
         act = layer.mode[4:]
-        print(act)
         states = layer.states
         # new_states = layer.new_states
         bidirect = 2 if bidirectional else 1
@@ -199,3 +197,216 @@ def name(num, name):
 
 
 
+@OpMapper(["LSTM"])
+class LSTM():
+    # support v1-v11
+
+    @classmethod
+    def concat_params(cls, num, weight_ih, weight_hh, bias_ih, bias_hh, bidrectional, hidden_size):
+
+        def reform_weights(weights, hidden_size):
+            reform_permutation = [(0, 1), (3, 4), (1, 3)]  # slice bounds (in units of hidden_size) that re-stack the gate blocks into ONNX's iofc order
+            slices = []
+            for x, y in reform_permutation:
+                start = x * hidden_size
+                end = y * hidden_size
+                slices.append(weights[start:end])
+            return np.concatenate(slices, axis=0)
+        ih_i, hh_i, b_i = None, None, None
+        if bidrectional:
+            id = num * 2
+            # get i-th rnn layer's weights - ih_i input to hidden
+            ih_i_forward = weight_ih[id]
+            ih_i_reverse = weight_ih[id + 1]
+            ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
+            ih_i_reverse = tlx.convert_to_numpy(ih_i_reverse)
+            ih_i_forward = reform_weights(ih_i_forward, hidden_size)
+            ih_i_reverse = reform_weights(ih_i_reverse, hidden_size)
+            ih_i_forward = ih_i_forward[np.newaxis, :, :]
+            ih_i_reverse = ih_i_reverse[np.newaxis, :, :]
+            ih_i = np.concatenate((ih_i_forward, ih_i_reverse), axis=0)
+
+            # get i-th rnn layer's weights - hh_i hidden to hidden
+            hh_i_forward = weight_hh[id]
+            hh_i_reverse = weight_hh[id + 1]
+            hh_i_forward = tlx.convert_to_numpy(hh_i_forward)
+            hh_i_reverse = tlx.convert_to_numpy(hh_i_reverse)
+            hh_i_forward = reform_weights(hh_i_forward, hidden_size)
+            hh_i_reverse = reform_weights(hh_i_reverse, hidden_size)
+            hh_i_forward = hh_i_forward[np.newaxis, :, :]
+            hh_i_reverse = hh_i_reverse[np.newaxis, :, :]
+            hh_i = np.concatenate((hh_i_forward, hh_i_reverse), axis=0)
+
+            if bias_ih is not None:
+                # get i-th rnn layer's bias - ih_b input to hidden
+                b_ih_forward = bias_ih[id]
+                b_ih_reverse = bias_ih[id + 1]
+                b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
+                b_ih_reverse = tlx.convert_to_numpy(b_ih_reverse)
+                b_ih_forward = reform_weights(b_ih_forward, hidden_size)
+                b_ih_reverse = reform_weights(b_ih_reverse, hidden_size)
+                b_ih_forward = b_ih_forward[np.newaxis, :]
+                b_ih_reverse = b_ih_reverse[np.newaxis, :]
+                # get i-th rnn layer's bias - hh_b hidden to hidden
+                b_hh_forward = bias_hh[id]
+                b_hh_reverse = bias_hh[id + 1]
+                b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
+                b_hh_reverse = tlx.convert_to_numpy(b_hh_reverse)
+                b_hh_forward = reform_weights(b_hh_forward, hidden_size)
+                b_hh_reverse = reform_weights(b_hh_reverse, hidden_size)
+                b_hh_forward = b_hh_forward[np.newaxis, :]
+                b_hh_reverse = b_hh_reverse[np.newaxis, :]
+
+                # concat bias
+                b_forward = np.concatenate((b_ih_forward, b_hh_forward), axis=-1)
+                b_reverse = np.concatenate((b_ih_reverse, b_hh_reverse), axis=-1)
+                b_i = np.concatenate((b_forward, b_reverse), axis=0)
+        else:
+            # get i-th rnn layer's weights - ih_i input to hidden
+            ih_i_forward = weight_ih[num]
+            ih_i_forward = tlx.convert_to_numpy(ih_i_forward)
+            ih_i_forward = reform_weights(ih_i_forward, hidden_size)
+            ih_i = ih_i_forward[np.newaxis, :, :]
+
+            # get i-th rnn layer's weights - hh_i hidden to hidden
+            hh_i_forward = weight_hh[num]
+            hh_i_forward = tlx.convert_to_numpy(hh_i_forward)
+            hh_i_forward = reform_weights(hh_i_forward, hidden_size)
+            hh_i = hh_i_forward[np.newaxis, :, :]
+
+            if bias_ih is not None:
+                # get i-th rnn layer's bias - ih_b input to hidden
+                b_ih_forward = bias_ih[num]
+                b_ih_forward = tlx.convert_to_numpy(b_ih_forward)
+                b_ih_forward = reform_weights(b_ih_forward, hidden_size)
+                b_ih_forward = b_ih_forward[np.newaxis, :]
+                # get i-th rnn layer's bias - hh_b hidden to hidden
+                b_hh_forward = bias_hh[num]
+                b_hh_forward = tlx.convert_to_numpy(b_hh_forward)
+                b_hh_forward = reform_weights(b_hh_forward, hidden_size)
+                b_hh_forward = b_hh_forward[np.newaxis, :]
+
+                # concat bias
+                b_i = np.concatenate((b_ih_forward, b_hh_forward), axis=-1)
+
+        return ih_i, hh_i, b_i
+
+    @classmethod
+    def concat_states(cls, num, states, bidrectional):
+        states_h = tlx.convert_to_numpy(states[0])
+        states_c = tlx.convert_to_numpy(states[1])
+        if bidrectional:
+            id = num * 2
+            states_hi = states_h[id: id+2, :, :]
+            states_ci = states_c[id: id+2, :, :]
+        else:
+            states_hi = states_h[num, :, :]
+            states_hi = states_hi[np.newaxis, :, :]
+            states_ci = states_c[num, :, :]
+            states_ci = states_ci[np.newaxis, :, :]
+        return states_hi, states_ci
+
+    @classmethod
+    def version_1(cls, node, **kwargs):
+        onnx_node = []
+        onnx_value = []
+        onnx_init = []
+
+        op_type = "LSTM"
+        attr_dict = OrderedDict()
+        # get in_node_name out_node_name
+        x_name = node['in_nodes_name'][0]
+        out_name = node['out_nodes_name'][0]
+        x_shape = node['in_tensors'][0]
+        out_shape = node['out_tensors'][0]
+
+        # get data_type
+        data_type = node['dtype']
+        tensor_type = NP_TYPE_TO_TENSOR_TYPE[data_type]
+
+        # get cur_node_layer node_index
+        layer = node['node'].layer
+        layer_name = layer.__class__.__name__
+
+
+        # get layer attr
+        input_size = layer.input_size
+        hidden_size = layer.hidden_size
+        num_layers = layer.num_layers
+        bias = layer.bias
+        batch_first = layer.batch_first
+        # dropout = layer.dropout # we don't need dropout inference
+        bidirectional = layer.bidirectional
+        states = layer.states
+        # new_states = layer.new_states
+        bidirect = 2 if bidirectional else 1
+
+        # get layer weights
+        weight_ih = layer.weight_ih
+        weight_hh = layer.weight_hh
+        bias_ih = None
+        bias_hh = None
+        if bias:
+            bias_ih = layer.bias_ih
+            bias_hh = layer.bias_hh
+
+        attr_dict["direction"] = "bidirectional" if bidirectional else "forward"
+        attr_dict["layout"] = 1 if batch_first else 0
+        attr_dict["hidden_size"] = hidden_size
+        attr_dict["activations"] = ['Sigmoid', 'Tanh', 'Tanh'] * bidirect
+        attr_dict["input_forget"] = 0
+
+        def name(num, name):
+            return layer_name + '_' + name + '_' + str(num)
+
+        input = x_name
+        for i in range(num_layers):
+            w_i, r_i, b_i = cls.concat_params(i, weight_ih, weight_hh, bias_ih, bias_hh, bidirectional, hidden_size)
+            attr_dict["inputs"] = [input]
+            w_i_name = name(i, "w")
+            attr_dict["inputs"].append(w_i_name)
+            w_i_init = numpy_helper.from_array(w_i, w_i_name)
+            onnx_init.append(w_i_init)
+            r_i_name = name(i, 'r')
+            attr_dict["inputs"].append(r_i_name)
+            r_i_init = numpy_helper.from_array(r_i, r_i_name)
+            onnx_init.append(r_i_init)
+            if b_i is not None:
+                b_i_name = name(i, 'b')
+                attr_dict["inputs"].append(b_i_name)
+                b_i_init = numpy_helper.from_array(b_i, b_i_name)
+                onnx_init.append(b_i_init)
+            else:
+                attr_dict["inputs"].append("")
+            # add sequence_lens into inputs
+            if states is not None:
+                state_hi_name = name(i, 'h')
+                attr_dict["inputs"].append("")
+                attr_dict["inputs"].append(state_hi_name)
+
+                state_ci_name = name(i, 'c')
+                attr_dict["inputs"].append(state_ci_name)
+                state_hi, state_ci = cls.concat_states(i, states, bidirectional)
+                state_hi_init = numpy_helper.from_array(state_hi, state_hi_name)
+                onnx_init.append(state_hi_init)
+                state_ci_init = numpy_helper.from_array(state_ci, state_ci_name)
+                onnx_init.append(state_ci_init)
+
+            attr_dict["outputs"] = [name(i, 'y')]
+            rnn_node, y_out = make_node(op_type, **attr_dict)
+            onnx_node.append(rnn_node)
+            transpose_node, y_out_T = make_node("Transpose", inputs=[y_out], outputs=[y_out + "_T"], perm=[0, 2, 1, 3])
+            onnx_node.append(transpose_node)
+            shape = np.array([0, 0, -1], dtype=np.int64)
+            shape_name = name(i, 'shape')
+            shape_value = numpy_helper.from_array(shape, shape_name)
+            onnx_init.append(shape_value)
+            if i + 1 < num_layers:
+                reshape_output = [y_out + "_R"]
+                reshape_node, y_out_R = make_node("Reshape", inputs=[y_out_T, shape_name], outputs=reshape_output)
+                input = y_out_R
+            else:
+                reshape_node, y_out_R = make_node("Reshape", inputs=[y_out_T, shape_name], outputs=[out_name])
+            onnx_node.append(reshape_node)
+
+        return onnx_node, onnx_value, onnx_init
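
A note on the gate reorder that distinguishes this LSTM mapper from the RNN one: judging from the permutation, TensorLayerX stacks each layer's four gate blocks along the first axis of weight_ih/weight_hh in i, f, g, o order (the PyTorch convention), while the ONNX LSTM operator expects W, R, and B with the blocks concatenated in i, o, f, c order. The slice permutation [(0, 1), (3, 4), (1, 3)] in reform_weights re-stacks rows [0:h], [3h:4h], [h:3h] to perform exactly that reorder. A standalone sketch (under the stated layout assumption; not part of the commit) that makes the reorder visible:

# Label each gate block by number, apply the same permutation as
# reform_weights, and watch the block order change.
import numpy as np

hidden_size = 2
# assumed source layout: gate blocks stacked as i, f, g, o
blocks = [np.full((hidden_size, 3), i) for i in range(4)]  # 0=i, 1=f, 2=g, 3=o
weights = np.concatenate(blocks, axis=0)                   # shape (4*h, in)

def reform_weights(weights, hidden_size):
    # rows [0h:1h] (i), [3h:4h] (o), [1h:3h] (f, g) -> ONNX's i, o, f, c order
    reform_permutation = [(0, 1), (3, 4), (1, 3)]
    slices = [weights[x * hidden_size:y * hidden_size] for x, y in reform_permutation]
    return np.concatenate(slices, axis=0)

print(reform_weights(weights, hidden_size)[:, 0])  # -> [0 0 3 3 1 1 2 2]

The per-layer Transpose/Reshape pair at the end of version_1 serves the stacked layers: with the default layout the ONNX LSTM emits Y as [seq_length, num_directions, batch_size, hidden_size], so perm=[0, 2, 1, 3] followed by a Reshape to [0, 0, -1] flattens it to [seq_length, batch_size, num_directions * hidden_size], the shape the next layer (or the graph output) expects.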
