Skip to content

Commit 8ebde5a

Browse files
committed
update test code for recurrent
1 parent b2e6ccc commit 8ebde5a

File tree

1 file changed

+51
-2
lines changed

1 file changed

+51
-2
lines changed

tests/test_layers_recurrent.py

Lines changed: 51 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,29 @@
8484
if net.count_params() != 7160:
8585
raise Exception("params dont match")
8686

87+
# n_layer=2
88+
net = tl.layers.EmbeddingInputlayer(inputs=input_data, vocabulary_size=vocab_size, embedding_size=hidden_size, name='emb2')
89+
net = tl.layers.BiRNNLayer(
90+
net, cell_fn=tf.contrib.rnn.BasicLSTMCell, n_hidden=hidden_size, n_steps=num_steps, n_layer=2, return_last=False, return_seq_2d=False, name='birnn2')
91+
92+
# net.print_layers()
93+
# net.print_params(False)
94+
#
95+
# shape = net.outputs.get_shape().as_list()
96+
# if shape[1:3] != [num_steps, hidden_size * 2]:
97+
# raise Exception("shape dont match")
98+
#
99+
# if len(net.all_layers) != 2:
100+
# raise Exception("layers dont match")
101+
#
102+
# if len(net.all_params) != 5:
103+
# raise Exception("params dont match")
104+
#
105+
# if net.count_params() != 7160:
106+
# raise Exception("params dont match")
107+
#
108+
# exit()
109+
87110
## ConvLSTMLayer TODO
88111
# image_size = 100
89112
# batch_size = 10
@@ -141,6 +164,20 @@
141164
if net.count_params() != 4510:
142165
raise Exception("params dont match")
143166

167+
# n_layer=3
168+
nin = tl.layers.EmbeddingInputlayer(inputs=input_seqs, vocabulary_size=vocab_size, embedding_size=embedding_size, name='seq_embedding2')
169+
rnn = tl.layers.DynamicRNNLayer(
170+
nin,
171+
cell_fn=tf.contrib.rnn.BasicLSTMCell,
172+
n_hidden=embedding_size,
173+
dropout=(keep_prob if is_train else None),
174+
sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
175+
n_layer=3,
176+
return_last=False,
177+
return_seq_2d=True,
178+
name='dynamicrnn2')
179+
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o2")
180+
144181
## BiDynamic Synced input and output
145182
rnn = tl.layers.BiDynamicRNNLayer(
146183
nin,
@@ -151,7 +188,7 @@
151188
return_last=False,
152189
return_seq_2d=True,
153190
name='bidynamicrnn')
154-
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o2")
191+
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o3")
155192

156193
net.print_layers()
157194
net.print_params(False)
@@ -173,6 +210,18 @@
173210
if net.count_params() != 8390:
174211
raise Exception("params dont match")
175212

213+
# n_layer=2
214+
rnn = tl.layers.BiDynamicRNNLayer(
215+
nin,
216+
cell_fn=tf.contrib.rnn.BasicLSTMCell,
217+
n_hidden=embedding_size,
218+
dropout=(keep_prob if is_train else None),
219+
sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
220+
return_last=False,
221+
return_seq_2d=True,
222+
name='bidynamicrnn2')
223+
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o3")
224+
176225
## Seq2Seq
177226
from tensorlayer.layers import EmbeddingInputlayer, Seq2Seq, retrieve_seq_length_op2, DenseLayer
178227
batch_size = 32
@@ -198,7 +247,7 @@
198247
decode_sequence_length=retrieve_seq_length_op2(decode_seqs),
199248
initial_state_encode=None,
200249
dropout=None,
201-
n_layer=1,
250+
n_layer=2,
202251
return_seq_2d=True,
203252
name='Seq2seq')
204253
net = DenseLayer(net, n_units=10000, act=tf.identity, name='oo')

0 commit comments

Comments
 (0)