|
# Verify the total trainable-parameter count of the single-layer BiRNN net.
# NOTE(review): 7160 is the expected count for this configuration — confirm
# against the layer's parameter breakdown if hyperparameters change.
expected_param_count = 7160
if net.count_params() != expected_param_count:
    raise Exception("params dont match")
|
# ---- BiRNNLayer with stacked cells (n_layer=2) ----
# Build: token embedding -> 2-layer bidirectional LSTM over num_steps steps.
net = tl.layers.EmbeddingInputlayer(
    inputs=input_data,
    vocabulary_size=vocab_size,
    embedding_size=hidden_size,
    name='emb2',
)
net = tl.layers.BiRNNLayer(
    net,
    cell_fn=tf.contrib.rnn.BasicLSTMCell,
    n_hidden=hidden_size,
    n_steps=num_steps,
    n_layer=2,
    return_last=False,
    return_seq_2d=False,
    name='birnn2',
)

net.print_layers()
net.print_params(False)

# Bidirectional output concatenates forward/backward states -> 2 * hidden_size.
out_shape = net.outputs.get_shape().as_list()
if out_shape[1:3] != [num_steps, hidden_size * 2]:
    raise Exception("shape dont match")

if len(net.all_layers) != 2:
    raise Exception("layers dont match")

if len(net.all_params) != 9:
    raise Exception("params dont match")

if net.count_params() != 13720:
    raise Exception("params dont match")

## ConvLSTMLayer TODO
# image_size = 100
# batch_size = 10
|
# Verify the trainable-parameter count of the preceding net.
# NOTE(review): 4510 is the expected count for this configuration — revisit
# if embedding_size/vocab_size change.
expected_param_count = 4510
if net.count_params() != expected_param_count:
    raise Exception("params dont match")
|
# ---- DynamicRNNLayer with stacked cells (n_layer=3) ----
# Embedding -> 3-layer dynamic LSTM (variable-length sequences) -> vocab logits.
# NOTE(review): unlike the sibling tests, no shape/param assertions follow this
# net in view — consider adding them for parity.
nin = tl.layers.EmbeddingInputlayer(
    inputs=input_seqs,
    vocabulary_size=vocab_size,
    embedding_size=embedding_size,
    name='seq_embedding2',
)
rnn = tl.layers.DynamicRNNLayer(
    nin,
    cell_fn=tf.contrib.rnn.BasicLSTMCell,
    n_hidden=embedding_size,
    # Dropout only during training; None disables the DropoutWrapper.
    dropout=(keep_prob if is_train else None),
    sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
    n_layer=3,
    return_last=False,
    return_seq_2d=True,
    name='dynamicrnn2',
)
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o2")
144 | 179 | ## BiDynamic Synced input and output |
145 | 180 | rnn = tl.layers.BiDynamicRNNLayer( |
146 | 181 | nin, |
|
151 | 186 | return_last=False, |
152 | 187 | return_seq_2d=True, |
153 | 188 | name='bidynamicrnn') |
# Project the bidirectional RNN features back onto the vocabulary
# (renamed to "o3" so it does not collide with the "o2" dense layer above).
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o3")
|
# Dump the layer stack and parameter summary for manual inspection.
# NOTE(review): False presumably suppresses detailed parameter printing —
# confirm against print_params' signature.
net.print_layers()
net.print_params(False)
|
# Verify the trainable-parameter count of the BiDynamicRNN net.
# NOTE(review): 8390 is the expected count for this configuration.
expected_param_count = 8390
if net.count_params() != expected_param_count:
    raise Exception("params dont match")
|
# ---- BiDynamicRNNLayer with stacked cells (n_layer=2) ----
rnn = tl.layers.BiDynamicRNNLayer(
    nin,
    cell_fn=tf.contrib.rnn.BasicLSTMCell,
    n_hidden=embedding_size,
    # Dropout only during training; None disables the DropoutWrapper.
    dropout=(keep_prob if is_train else None),
    sequence_length=tl.layers.retrieve_seq_length_op2(input_seqs),
    n_layer=2,
    return_last=False,
    return_seq_2d=True,
    name='bidynamicrnn2',
)
net = tl.layers.DenseLayer(rnn, n_units=vocab_size, name="o4")

net.print_layers()
net.print_params(False)

# Bidirectional features: forward + backward concatenated -> 2 * embedding_size.
rnn_shape = rnn.outputs.get_shape().as_list()
if rnn_shape[-1] != embedding_size * 2:
    raise Exception("shape dont match")

# Dense head maps features onto vocab_size logits.
net_shape = net.outputs.get_shape().as_list()
if net_shape[-1] != vocab_size:
    raise Exception("shape dont match")

if len(net.all_layers) != 3:
    raise Exception("layers dont match")

if len(net.all_params) != 11:
    raise Exception("params dont match")

if net.count_params() != 18150:
    raise Exception("params dont match")
## Seq2Seq
# Import the layer classes directly so the Seq2Seq graph below reads tersely.
from tensorlayer.layers import EmbeddingInputlayer, Seq2Seq, retrieve_seq_length_op2, DenseLayer
batch_size = 32
|
198 | 266 | decode_sequence_length=retrieve_seq_length_op2(decode_seqs), |
199 | 267 | initial_state_encode=None, |
200 | 268 | dropout=None, |
201 | | - n_layer=1, |
| 269 | + n_layer=2, |
202 | 270 | return_seq_2d=True, |
203 | 271 | name='Seq2seq') |
204 | 272 | net = DenseLayer(net, n_units=10000, act=tf.identity, name='oo') |
|
# The Seq2Seq net is expected to expose exactly 5 layers.
layer_count = len(net.all_layers)
if layer_count != 5:
    raise Exception("layers dont match")
217 | 285 |
|
# NOTE(review): raised from 7 to 11 when n_layer went from 1 to 2 — the extra
# stacked LSTM cells presumably add two weight/bias pairs per direction; confirm.
if len(net.all_params) != 11:
    raise Exception("params dont match")
220 | 288 |
|
# NOTE(review): total grew from 4651600 to 5293200 with the extra RNN layer —
# the delta should equal the stacked cells' parameters; verify if config changes.
if net.count_params() != 5293200:
    raise Exception("params dont match")
0 commit comments