Skip to content

Commit c72b053

Browse files
committed
fix rnn dropout bug
1 parent 8ebde5a commit c72b053

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

tensorlayer/layers/recurrent.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,20 @@ def __init__(
379379
MultiRNNCell_fn = tf.contrib.rnn.MultiRNNCell
380380
except Exception:
381381
MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
382-
383-
try:
384-
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
385-
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
386-
except Exception:
387-
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
388-
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
382+
if dropout:
383+
try:
384+
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
385+
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
386+
except Exception:
387+
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
388+
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
389+
else:
390+
try:
391+
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
392+
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
393+
except Exception:
394+
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
395+
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
389396

390397
# Initial state of RNN
391398
if fw_initial_state is None:
@@ -1081,12 +1088,18 @@ def __init__(
10811088
MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
10821089

10831090
# cell_instance_fn2=cell_instance_fn # HanSheng
1084-
try:
1085-
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
1086-
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
1087-
except Exception: # when GRU
1088-
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
1089-
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
1091+
if dropout:
1092+
try:
1093+
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
1094+
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
1095+
except Exception: # when GRU
1096+
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
1097+
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
1098+
else:
1099+
try:
1100+
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
1101+
except Exception: # when GRU
1102+
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
10901103

10911104
# self.cell=cell_instance_fn() # HanSheng
10921105

@@ -1338,8 +1351,12 @@ def __init__(
13381351
sequence_length = retrieve_seq_length_op(self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))
13391352

13401353
if n_layer > 1:
1341-
self.fw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
1342-
self.bw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
1354+
if dropout:
1355+
self.fw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
1356+
self.bw_cell = [cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)]
1357+
else:
1358+
self.fw_cell = [cell_creator() for _ in range(n_layer)]
1359+
self.bw_cell = [cell_creator() for _ in range(n_layer)]
13431360
from tensorflow.contrib.rnn import stack_bidirectional_dynamic_rnn
13441361
outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn(
13451362
cells_fw=self.fw_cell,

tests/test_layers_recurrent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@
220220
return_last=False,
221221
return_seq_2d=True,
222222
name='bidynamicrnn2')
223-
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o3")
223+
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o4")
224224

225225
## Seq2Seq
226226
from tensorlayer.layers import EmbeddingInputlayer, Seq2Seq, retrieve_seq_length_op2, DenseLayer

0 commit comments

Comments
 (0)