Skip to content

Commit b77676a

Browse files
committed
edit for TF1.0
1 parent 46deb2c commit b77676a

File tree

1 file changed

+41
-9
lines changed

1 file changed

+41
-9
lines changed

tensorlayer/layers.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3080,11 +3080,19 @@ def __init__(
30803080
if return_seq_2d:
30813081
# PTB tutorial: stack dense layer after that, or compute the cost from the output
30823082
# 2D Tensor [n_example, n_hidden]
3083-
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_hidden])
3083+
try: # TF1.0
3084+
self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, n_hidden])
3085+
except: # TF0.12
3086+
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_hidden])
3087+
3088+
30843089
else:
30853090
# <akara>: stack more RNN layer after that
30863091
# 3D Tensor [n_example/n_steps, n_steps, n_hidden]
3087-
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_steps, n_hidden])
3092+
try: # TF1.0
3093+
self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, n_steps, n_hidden])
3094+
except: # TF0.12
3095+
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_steps, n_hidden])
30883096

30893097
self.final_state = state
30903098

@@ -3271,11 +3279,18 @@ def __init__(
32713279
self.outputs = outputs
32723280
if return_seq_2d:
32733281
# 2D Tensor [n_example, n_hidden]
3274-
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_hidden*2])
3282+
try: # TF1.0
3283+
self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, n_hidden*2])
3284+
except: # TF0.12
3285+
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_hidden*2])
32753286
else:
32763287
# <akara>: stack more RNN layer after that
32773288
# 3D Tensor [n_example/n_steps, n_steps, n_hidden]
3278-
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_steps, n_hidden*2])
3289+
3290+
try: # TF1.0
3291+
self.outputs = tf.reshape(tf.concat(outputs,1), [-1, n_steps, n_hidden*2])
3292+
except: # TF0.12
3293+
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_steps, n_hidden*2])
32793294
self.fw_final_state = fw_state
32803295
self.bw_final_state = bw_state
32813296

@@ -3606,13 +3621,21 @@ def __init__(
36063621
if return_seq_2d:
36073622
# PTB tutorial:
36083623
# 2D Tensor [n_example, n_hidden]
3609-
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_hidden])
3624+
try: # TF1.0
3625+
self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, n_hidden])
3626+
except: # TF0.12
3627+
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_hidden])
36103628
else:
36113629
# <akara>:
36123630
# 3D Tensor [batch_size, n_steps(max), n_hidden]
36133631
max_length = tf.shape(outputs)[1]
36143632
batch_size = tf.shape(outputs)[0]
3615-
self.outputs = tf.reshape(tf.concat(1, outputs), [batch_size, max_length, n_hidden])
3633+
3634+
3635+
try: # TF1.0
3636+
self.outputs = tf.reshape(tf.concat(outputs, 1), [batch_size, max_length, n_hidden])
3637+
except: # TF0.12
3638+
self.outputs = tf.reshape(tf.concat(1, outputs), [batch_size, max_length, n_hidden])
36163639
# self.outputs = tf.reshape(tf.concat(1, outputs), [-1, max_length, n_hidden])
36173640

36183641
# Final state
@@ -3806,7 +3829,10 @@ def __init__(
38063829

38073830
print(" n_params : %d" % (len(rnn_variables)))
38083831
# Manage the outputs
3809-
outputs = tf.concat(2, outputs)
3832+
try: # TF1.0
3833+
outputs = tf.concat(outputs, 2)
3834+
except: # TF0.12
3835+
outputs = tf.concat(2, outputs)
38103836
if return_last:
38113837
# [batch_size, 2 * n_hidden]
38123838
self.outputs = advanced_indexing_op(outputs, sequence_length)
@@ -3815,13 +3841,19 @@ def __init__(
38153841
if return_seq_2d:
38163842
# PTB tutorial:
38173843
# 2D Tensor [n_example, 2 * n_hidden]
3818-
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, 2 * n_hidden])
3844+
try: # TF1.0
3845+
self.outputs = tf.reshape(tf.concat(outputs, 1), [-1, 2 * n_hidden])
3846+
except: # TF0.12
3847+
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, 2 * n_hidden])
38193848
else:
38203849
# <akara>:
38213850
# 3D Tensor [batch_size, n_steps(max), 2 * n_hidden]
38223851
max_length = tf.shape(outputs)[1]
38233852
batch_size = tf.shape(outputs)[0]
3824-
self.outputs = tf.reshape(tf.concat(1, outputs), [batch_size, max_length, 2 * n_hidden])
3853+
try: # TF1.0
3854+
self.outputs = tf.reshape(tf.concat(outputs, 1), [batch_size, max_length, 2 * n_hidden])
3855+
except: # TF0.12
3856+
self.outputs = tf.reshape(tf.concat(1, outputs), [batch_size, max_length, 2 * n_hidden])
38253857
# self.outputs = tf.reshape(tf.concat(1, outputs), [-1, max_length, 2 * n_hidden])
38263858

38273859
# Final state

0 commit comments

Comments
 (0)