|
89 | 89 | net = tl.layers.BiRNNLayer( |
90 | 90 | net, cell_fn=tf.contrib.rnn.BasicLSTMCell, n_hidden=hidden_size, n_steps=num_steps, n_layer=2, return_last=False, return_seq_2d=False, name='birnn2') |
91 | 91 |
|
92 | | -# net.print_layers() |
93 | | -# net.print_params(False) |
94 | | -# |
95 | | -# shape = net.outputs.get_shape().as_list() |
96 | | -# if shape[1:3] != [num_steps, hidden_size * 2]: |
97 | | -# raise Exception("shape dont match") |
98 | | -# |
99 | | -# if len(net.all_layers) != 2: |
100 | | -# raise Exception("layers dont match") |
101 | | -# |
102 | | -# if len(net.all_params) != 5: |
103 | | -# raise Exception("params dont match") |
104 | | -# |
105 | | -# if net.count_params() != 7160: |
106 | | -# raise Exception("params dont match") |
107 | | -# |
108 | | -# exit() |
| 92 | +net.print_layers() |
| 93 | +net.print_params(False) |
| 94 | + |
| 95 | +shape = net.outputs.get_shape().as_list() |
| 96 | +if shape[1:3] != [num_steps, hidden_size * 2]: |
| 97 | + raise Exception("shape dont match") |
| 98 | + |
| 99 | +if len(net.all_layers) != 2: |
| 100 | + raise Exception("layers dont match") |
| 101 | + |
| 102 | +if len(net.all_params) != 9: |
| 103 | + raise Exception("params dont match") |
| 104 | + |
| 105 | +if net.count_params() != 13720: |
| 106 | + raise Exception("params dont match") |
109 | 107 |
|
110 | 108 | ## ConvLSTMLayer TODO |
111 | 109 | # image_size = 100 |
|
217 | 215 | n_hidden=embedding_size, |
218 | 216 | dropout=(keep_prob if is_train else None), |
219 | 217 | sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs), |
| 218 | + n_layer=2, |
220 | 219 | return_last=False, |
221 | 220 | return_seq_2d=True, |
222 | 221 | name='bidynamicrnn2') |
223 | 222 | net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o4") |
224 | 223 |
|
| 224 | +net.print_layers() |
| 225 | +net.print_params(False) |
| 226 | + |
| 227 | +shape = rnn.outputs.get_shape().as_list() |
| 228 | +if shape[-1] != embedding_size * 2: |
| 229 | + raise Exception("shape dont match") |
| 230 | + |
| 231 | +shape = net.outputs.get_shape().as_list() |
| 232 | +if shape[-1] != vocab_size: |
| 233 | + raise Exception("shape dont match") |
| 234 | + |
| 235 | +if len(net.all_layers) != 3: |
| 236 | + raise Exception("layers dont match") |
| 237 | + |
| 238 | +if len(net.all_params) != 11: |
| 239 | + raise Exception("params dont match") |
| 240 | + |
| 241 | +if net.count_params() != 18150: |
| 242 | + raise Exception("params dont match") |
| 243 | + |
225 | 244 | ## Seq2Seq |
226 | 245 | from tensorlayer.layers import EmbeddingInputlayer, Seq2Seq, retrieve_seq_length_op2, DenseLayer |
227 | 246 | batch_size = 32 |
|
0 commit comments