Commit 7ef916f

Merge pull request #391 from tensorlayer/fix-rnn-bug
update test code for recurrent & fix the bug for n_layer>1
2 parents: b2e6ccc + f740f02

File tree

2 files changed: +104 −19
  tensorlayer/layers/recurrent.py
  tests/test_layers_recurrent.py


tensorlayer/layers/recurrent.py

Lines changed: 32 additions & 15 deletions
@@ -379,13 +379,20 @@ def __init__(
             MultiRNNCell_fn = tf.contrib.rnn.MultiRNNCell
         except Exception:
             MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
-
-        try:
-            self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
-            self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
-        except Exception:
-            self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
-            self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
+        if dropout:
+            try:
+                self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
+                self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
+            except Exception:
+                self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
+                self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
+        else:
+            try:
+                self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
+                self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
+            except Exception:
+                self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
+                self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])

         # Initial state of RNN
         if fw_initial_state is None:
@@ -1081,12 +1088,18 @@ def __init__(
             MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell

         # cell_instance_fn2=cell_instance_fn # HanSheng
-        try:
-            # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
-            self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
-        except Exception: # when GRU
-            # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
-            self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
+        if dropout:
+            try:
+                # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
+                self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
+            except Exception: # when GRU
+                # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
+                self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
+        else:
+            try:
+                self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
+            except Exception: # when GRU
+                self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])

         # self.cell=cell_instance_fn() # HanSheng

@@ -1338,8 +1351,12 @@ def __init__(
         sequence_length = retrieve_seq_length_op(self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))

         if n_layer > 1:
-            self.fw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
-            self.bw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
+            if dropout:
+                self.fw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
+                self.bw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
+            else:
+                self.fw_cell = [cell_creator() for _ in range(n_layer)]
+                self.bw_cell = [cell_creator() for _ in range(n_layer)]
             from tensorflow.contrib.rnn import stack_bidirectional_dynamic_rnn
             outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn(
                 cells_fw=self.fw_cell,
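Note on the fix: in these layers, cell_creator is built in one of two shapes depending on whether dropout is enabled, and only the dropout variant accepts the is_last keyword (used so output dropout is applied only after the top layer of a stack). The old code passed is_last unconditionally, so the no-dropout creator raised a TypeError on the n_layer > 1 path. Below is a minimal sketch of the assumed pattern; make_cell_creator is an illustrative name, not TensorLayer API, and the real definitions live earlier in recurrent.py:

import tensorflow as tf

def make_cell_creator(cell_fn, n_hidden, dropout):  # hypothetical helper, for illustration
    if dropout:
        # dropout may be a float or an (input_keep_prob, output_keep_prob) pair
        in_keep, out_keep = (dropout if isinstance(dropout, (tuple, list))
                             else (dropout, dropout))

        # Dropout variant: is_last lets the stack apply output dropout
        # only after the top layer; intermediate layers keep out_keep = 1.0.
        def cell_creator(is_last=True):
            cell = cell_fn(num_units=n_hidden)
            return tf.nn.rnn_cell.DropoutWrapper(
                cell, input_keep_prob=in_keep,
                output_keep_prob=out_keep if is_last else 1.0)
    else:
        # No-dropout variant takes no arguments, so the pre-fix call
        # cell_creator(is_last=...) failed with a TypeError.
        def cell_creator():
            return cell_fn(num_units=n_hidden)

    return cell_creator

Branching the call sites on dropout, as this commit does, keeps both creator signatures intact instead of forcing a dummy is_last argument onto the plain creator.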

tests/test_layers_recurrent.py

Lines changed: 72 additions & 4 deletions
@@ -84,6 +84,27 @@
 if net.count_params() != 7160:
     raise Exception("params dont match")

+# n_layer=2
+net = tl.layers.EmbeddingInputlayer(inputs=input_data, vocabulary_size=vocab_size, embedding_size=hidden_size, name='emb2')
+net = tl.layers.BiRNNLayer(
+    net, cell_fn=tf.contrib.rnn.BasicLSTMCell, n_hidden=hidden_size, n_steps=num_steps, n_layer=2, return_last=False, return_seq_2d=False, name='birnn2')
+
+net.print_layers()
+net.print_params(False)
+
+shape = net.outputs.get_shape().as_list()
+if shape[1:3] != [num_steps, hidden_size * 2]:
+    raise Exception("shape dont match")
+
+if len(net.all_layers) != 2:
+    raise Exception("layers dont match")
+
+if len(net.all_params) != 9:
+    raise Exception("params dont match")
+
+if net.count_params() != 13720:
+    raise Exception("params dont match")
+
 ## ConvLSTMLayer TODO
 # image_size = 100
 # batch_size = 10
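A quick check of the asserted totals, counting a BasicLSTMCell as an (input + units) x 4*units kernel plus a 4*units bias. The test's hidden_size and vocab_size are defined earlier in the file; the values below are inferred from the assertions themselves:

# Inferred from the 7160 / 13720 assertions: hidden_size = 20, vocab_size = 30.
h, vocab = 20, 30
lstm = (h + h) * 4 * h + 4 * h      # one BasicLSTMCell: 3280 parameters
print(vocab * h + 2 * lstm)         # n_layer=1: embedding + fw + bw = 7160
print(vocab * h + 4 * lstm)         # n_layer=2: two cells per direction = 13720
# 9 variables in all_params: 1 embedding + a (kernel, bias) pair per cell.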
@@ -141,6 +162,20 @@
 if net.count_params() != 4510:
     raise Exception("params dont match")

+# n_layer=3
+nin = tl.layers.EmbeddingInputlayer(inputs=input_seqs, vocabulary_size=vocab_size, embedding_size=embedding_size, name='seq_embedding2')
+rnn = tl.layers.DynamicRNNLayer(
+    nin,
+    cell_fn=tf.contrib.rnn.BasicLSTMCell,
+    n_hidden=embedding_size,
+    dropout=(keep_prob if is_train else None),
+    sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
+    n_layer=3,
+    return_last=False,
+    return_seq_2d=True,
+    name='dynamicrnn2')
+net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o2")
+
 ## BiDynamic Synced input and output
 rnn = tl.layers.BiDynamicRNNLayer(
     nin,
@@ -151,7 +186,7 @@
     return_last=False,
     return_seq_2d=True,
     name='bidynamicrnn')
-net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o2")
+net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o3")

 net.print_layers()
 net.print_params(False)
@@ -173,6 +208,39 @@
 if net.count_params() != 8390:
     raise Exception("params dont match")

+# n_layer=2
+rnn = tl.layers.BiDynamicRNNLayer(
+    nin,
+    cell_fn=tf.contrib.rnn.BasicLSTMCell,
+    n_hidden=embedding_size,
+    dropout=(keep_prob if is_train else None),
+    sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
+    n_layer=2,
+    return_last=False,
+    return_seq_2d=True,
+    name='bidynamicrnn2')
+net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o4")
+
+net.print_layers()
+net.print_params(False)
+
+shape = rnn.outputs.get_shape().as_list()
+if shape[-1] != embedding_size * 2:
+    raise Exception("shape dont match")
+
+shape = net.outputs.get_shape().as_list()
+if shape[-1] != vocab_size:
+    raise Exception("shape dont match")
+
+if len(net.all_layers) != 3:
+    raise Exception("layers dont match")
+
+if len(net.all_params) != 11:
+    raise Exception("params dont match")
+
+if net.count_params() != 18150:
+    raise Exception("params dont match")
+
 ## Seq2Seq
 from tensorlayer.layers import EmbeddingInputlayer, Seq2Seq, retrieve_seq_length_op2, DenseLayer
 batch_size = 32
@@ -198,7 +266,7 @@
     decode_sequence_length=retrieve_seq_length_op2(decode_seqs),
     initial_state_encode=None,
     dropout=None,
-    n_layer=1,
+    n_layer=2,
     return_seq_2d=True,
     name='Seq2seq')
 net = DenseLayer(net, n_units=10000, act=tf.identity, name='oo')
@@ -215,8 +283,8 @@
 if len(net.all_layers) != 5:
     raise Exception("layers dont match")

-if len(net.all_params) != 7:
+if len(net.all_params) != 11:
     raise Exception("params dont match")

-if net.count_params() != 4651600:
+if net.count_params() != 5293200:
     raise Exception("params dont match")
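The new Seq2Seq totals follow from adding one BasicLSTMCell to both the encoder and the decoder stack. A hedged check, assuming an embedding/hidden size of 200 and a shared encode/decode embedding over a 10,000-word vocabulary (consistent with the DenseLayer's n_units=10000 and with the old 4,651,600 total):

e, vocab = 200, 10000               # assumed sizes, inferred from the totals
lstm = (e + e) * 4 * e + 4 * e      # one BasicLSTMCell: 320,800 parameters
emb = vocab * e                     # shared embedding matrix: 2,000,000
dense = e * vocab + vocab           # output DenseLayer: 2,010,000
print(emb + 2 * lstm + dense)       # n_layer=1: 4,651,600 (old assertion)
print(emb + 4 * lstm + dense)       # n_layer=2: 5,293,200 (new assertion)
# all_params grows 7 -> 11: one extra (kernel, bias) pair per side.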
