
Commit 645f7b8

[layers] dynamic rnn in/out
1 parent 335f89b commit 645f7b8

File tree

1 file changed, +23 -19 lines changed

tensorlayer/layers.py

old mode 100644, new mode 100755
Lines changed: 23 additions & 19 deletions
@@ -3948,10 +3948,9 @@ def __init__(
 #                     cell_instance_fn1(),
 #                     input_keep_prob=in_keep_prob,
 #                     output_keep_prob=out_keep_prob)
-            self.cell = DropoutWrapper_fn(
-                self.cell,
-                input_keep_prob=in_keep_prob,
-                output_keep_prob=out_keep_prob)
+            self.cell = DropoutWrapper_fn(self.cell,
+                input_keep_prob=in_keep_prob, output_keep_prob=1.0)#out_keep_prob)
+
         # Apply multiple layers
         if n_layer > 1:
             try:
@@ -3963,10 +3962,14 @@ def __init__(
             try:
                 # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
                 self.cell = MultiRNNCell_fn([self.cell] * n_layer, state_is_tuple=True)
-            except:
+            except: # when GRU
                 # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
                 self.cell = MultiRNNCell_fn([self.cell] * n_layer)
 
+        if dropout:
+            self.cell = DropoutWrapper_fn(self.cell,
+                input_keep_prob=1.0, output_keep_prob=out_keep_prob)
+
         # self.cell=cell_instance_fn() # HanSheng
 
         # Initialize initial_state
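
Taken together, these two hunks split dropout around the stack: input dropout is now applied to the base cell before MultiRNNCell replicates it, and output dropout is applied once around the finished stack, so intermediate layers no longer receive dropout on both sides. A minimal sketch of the resulting wrapping order, assuming DropoutWrapper_fn and MultiRNNCell_fn resolve to the TensorFlow 1.x tf.contrib.rnn classes (sizes and keep probabilities here are illustrative):

import tensorflow as tf

n_hidden, n_layer = 200, 3
in_keep_prob, out_keep_prob = 0.8, 0.5

cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)

# input dropout only, on the base cell before stacking
cell = tf.contrib.rnn.DropoutWrapper(
    cell, input_keep_prob=in_keep_prob, output_keep_prob=1.0)

# replicate the (input-dropped) cell, as the diff does with [self.cell] * n_layer
cell = tf.contrib.rnn.MultiRNNCell([cell] * n_layer, state_is_tuple=True)

# output dropout only, once around the whole stack
cell = tf.contrib.rnn.DropoutWrapper(
    cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
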
@@ -4333,23 +4336,24 @@ class Seq2Seq(Layer):
 >>> decode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="decode_seqs")
 >>> target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_seqs")
 >>> target_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_mask") # tl.prepro.sequences_get_mask()
->>> with tf.variable_scope("model") as vs:#, reuse=reuse):
+>>> with tf.variable_scope("model"):
 ... # for chatbot, you can use the same embedding layer,
 ... # for translation, you may want to use 2 seperated embedding layers
->>> net_encode = EmbeddingInputlayer(
-...     inputs = encode_seqs,
-...     vocabulary_size = 10000,
-...     embedding_size = 200,
-...     name = 'seq_embedding')
->>> vs.reuse_variables()
->>> tl.layers.set_name_reuse(True)
->>> net_decode = EmbeddingInputlayer(
-...     inputs = decode_seqs,
-...     vocabulary_size = 10000,
-...     embedding_size = 200,
-...     name = 'seq_embedding')
+>>> with tf.variable_scope("embedding") as vs:
+>>>     net_encode = EmbeddingInputlayer(
+...         inputs = encode_seqs,
+...         vocabulary_size = 10000,
+...         embedding_size = 200,
+...         name = 'seq_embedding')
+>>>     vs.reuse_variables()
+>>>     tl.layers.set_name_reuse(True)
+>>>     net_decode = EmbeddingInputlayer(
+...         inputs = decode_seqs,
+...         vocabulary_size = 10000,
+...         embedding_size = 200,
+...         name = 'seq_embedding')
 >>> net = Seq2Seq(net_encode, net_decode,
-...     cell_fn = tf.nn.rnn_cell.LSTMCell,
+...     cell_fn = tf.contrib.rnn.BasicLSTMCell,
 ...     n_hidden = 200,
 ...     initializer = tf.random_uniform_initializer(-0.1, 0.1),
 ...     encode_sequence_length = retrieve_seq_length_op2(encode_seqs),
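
The revised docstring nests both EmbeddingInputlayer calls inside one "embedding" scope and calls vs.reuse_variables() between them, so encoder and decoder share a single embedding matrix rather than colliding on the name 'seq_embedding'. A minimal sketch of the TF 1.x scope mechanism this relies on, using plain tf.get_variable (the shapes are illustrative):

import tensorflow as tf

with tf.variable_scope("embedding") as vs:
    # first call creates the variable embedding/W
    W_enc = tf.get_variable("W", shape=[10000, 200])
    vs.reuse_variables()  # put the scope into reuse mode
    # second call returns the existing variable instead of raising
    W_dec = tf.get_variable("W", shape=[10000, 200])

assert W_enc is W_dec  # one shared embedding matrix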
