Skip to content

Commit bc6b855

Browse files
authored
Merge pull request #96 from pnpnpn/fix_GRUCell
Attention Seq2Seq Wrapper : tf.contrib.rnn.* issues for TF1.0 by @pnpnpn
2 parents 04e43f1 + fbc593e commit bc6b855

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

tensorlayer/layers.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4678,12 +4678,23 @@ def sampled_loss(inputs, labels):
46784678

46794679
# ============ Seq Encode Layer =============
46804680
# Create the internal multi-layer cell for our RNN.
4681-
single_cell = tf.nn.rnn_cell.GRUCell(size)
4681+
try: # TF1.0
4682+
single_cell = tf.contrib.rnn.GRUCell(size)
4683+
except:
4684+
single_cell = tf.nn.rnn_cell.GRUCell(size)
4685+
46824686
if use_lstm:
4683-
single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
4687+
try: # TF1.0
4688+
single_cell = tf.contrib.rnn.BasicLSTMCell(size)
4689+
except:
4690+
single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
4691+
46844692
cell = single_cell
46854693
if num_layers > 1:
4686-
cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
4694+
try: # TF1.0
4695+
cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)
4696+
except:
4697+
cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
46874698

46884699
# ============== Seq Decode Layer ============
46894700
# The seq2seq function: we use embedding for the input and attention.

0 commit comments

Comments
 (0)