Skip to content

Commit 686e2c7

Browse files
committed
fixed test code for rnn
1 parent a60fed3 commit 686e2c7

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

tests/test_layers_recurrent.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,21 @@
8989
net = tl.layers.BiRNNLayer(
9090
net, cell_fn=tf.contrib.rnn.BasicLSTMCell, n_hidden=hidden_size, n_steps=num_steps, n_layer=2, return_last=False, return_seq_2d=False, name='birnn2')
9191

92-
# net.print_layers()
93-
# net.print_params(False)
94-
#
95-
# shape = net.outputs.get_shape().as_list()
96-
# if shape[1:3] != [num_steps, hidden_size * 2]:
97-
# raise Exception("shape dont match")
98-
#
99-
# if len(net.all_layers) != 2:
100-
# raise Exception("layers dont match")
101-
#
102-
# if len(net.all_params) != 5:
103-
# raise Exception("params dont match")
104-
#
105-
# if net.count_params() != 7160:
106-
# raise Exception("params dont match")
107-
#
108-
# exit()
92+
net.print_layers()
93+
net.print_params(False)
94+
95+
shape = net.outputs.get_shape().as_list()
96+
if shape[1:3] != [num_steps, hidden_size * 2]:
97+
raise Exception("shape dont match")
98+
99+
if len(net.all_layers) != 2:
100+
raise Exception("layers dont match")
101+
102+
if len(net.all_params) != 9:
103+
raise Exception("params dont match")
104+
105+
if net.count_params() != 13720:
106+
raise Exception("params dont match")
109107

110108
## ConvLSTMLayer TODO
111109
# image_size = 100
@@ -217,11 +215,32 @@
217215
n_hidden=embedding_size,
218216
dropout=(keep_prob if is_train else None),
219217
sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
218+
n_layer=2,
220219
return_last=False,
221220
return_seq_2d=True,
222221
name='bidynamicrnn2')
223222
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o4")
224223

224+
net.print_layers()
225+
net.print_params(False)
226+
227+
shape = rnn.outputs.get_shape().as_list()
228+
if shape[-1] != embedding_size * 2:
229+
raise Exception("shape dont match")
230+
231+
shape = net.outputs.get_shape().as_list()
232+
if shape[-1] != vocab_size:
233+
raise Exception("shape dont match")
234+
235+
if len(net.all_layers) != 3:
236+
raise Exception("layers dont match")
237+
238+
if len(net.all_params) != 11:
239+
raise Exception("params dont match")
240+
241+
if net.count_params() != 18150:
242+
raise Exception("params dont match")
243+
225244
## Seq2Seq
226245
from tensorlayer.layers import EmbeddingInputlayer, Seq2Seq, retrieve_seq_length_op2, DenseLayer
227246
batch_size = 32

0 commit comments

Comments (0)