Skip to content

Commit 025e874

Browse files
nebulaVluomai
authored andcommitted
update out_prob of dropout in several RNN layers (#357)
1 parent 2053282 commit 025e874

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

tensorlayer/layers/recurrent.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,10 @@ def __init__(
378378
DropoutWrapper_fn = tf.contrib.rnn.DropoutWrapper
379379
except Exception:
380380
DropoutWrapper_fn = tf.nn.rnn_cell.DropoutWrapper
381-
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0) # out_keep_prob)
381+
cell_creator = lambda is_last=True: \
382+
DropoutWrapper_fn(rnn_creator(),
383+
input_keep_prob=in_keep_prob,
384+
output_keep_prob=out_keep_prob if is_last else 1.0)
382385
else:
383386
cell_creator = rnn_creator
384387
self.fw_cell = cell_creator()
@@ -392,11 +395,11 @@ def __init__(
392395
MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
393396

394397
try:
395-
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
396-
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
398+
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
399+
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
397400
except Exception:
398-
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
399-
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
401+
self.fw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
402+
self.bw_cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
400403

401404
# Initial state of RNN
402405
if fw_initial_state is None:
@@ -1076,7 +1079,10 @@ def __init__(
10761079
# cell_instance_fn1(),
10771080
# input_keep_prob=in_keep_prob,
10781081
# output_keep_prob=out_keep_prob)
1079-
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0)
1082+
cell_creator = lambda is_last=True: \
1083+
DropoutWrapper_fn(rnn_creator(),
1084+
input_keep_prob=in_keep_prob,
1085+
output_keep_prob=out_keep_prob if is_last else 1.0)
10801086
else:
10811087
cell_creator = rnn_creator
10821088
self.cell = cell_creator()
@@ -1090,10 +1096,10 @@ def __init__(
10901096
# cell_instance_fn2=cell_instance_fn # HanSheng
10911097
try:
10921098
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
1093-
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
1099+
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)], state_is_tuple=True)
10941100
except Exception: # when GRU
10951101
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
1096-
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
1102+
self.cell = MultiRNNCell_fn([cell_creator(is_last=i == n_layer - 1) for i in range(n_layer)])
10971103

10981104
# self.cell=cell_instance_fn() # HanSheng
10991105

0 commit comments

Comments
 (0)