Skip to content

Commit 9f2af2e

Browse files
authored
Merge pull request #147 from CTTC/master
make the usage of rnn cells compatible with TF1.1
2 parents e4273f3 + dcef471 commit 9f2af2e

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

tensorlayer/layers.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from six.moves import xrange
1616
import random, warnings
1717
import copy
18-
18+
import inspect
1919
# __all__ = [
2020
# "Layer",
2121
# "DenseLayer",
@@ -3397,7 +3397,10 @@ def __init__(
33973397
# for input_ in tf.split(1, num_steps, inputs)]
33983398
# outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)
33993399
outputs = []
3400-
self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args)
3400+
if 'reuse' in inspect.getargspec(cell_fn.__init__).args:
3401+
self.cell = cell = cell_fn(num_units=n_hidden, reuse=tf.get_variable_scope().reuse, **cell_init_args)
3402+
else:
3403+
self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args)
34013404
if initial_state is None:
34023405
self.initial_state = cell.zero_state(batch_size, dtype=tf.float32) # 1.2.3
34033406
state = self.initial_state
@@ -3560,8 +3563,7 @@ def __init__(
35603563
raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps, n_features]")
35613564

35623565
with tf.variable_scope(name, initializer=initializer) as vs:
3563-
self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
3564-
self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
3566+
rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args)
35653567
# Apply dropout
35663568
if dropout:
35673569
if type(dropout) in [tuple, list]:
@@ -3576,14 +3578,14 @@ def __init__(
35763578
DropoutWrapper_fn = tf.contrib.rnn.DropoutWrapper
35773579
except:
35783580
DropoutWrapper_fn = tf.nn.rnn_cell.DropoutWrapper
3579-
self.fw_cell = DropoutWrapper_fn(
3580-
self.fw_cell,
3581-
input_keep_prob=in_keep_prob,
3582-
output_keep_prob=out_keep_prob)
3583-
self.bw_cell = DropoutWrapper_fn(
3584-
self.bw_cell,
3585-
input_keep_prob=in_keep_prob,
3586-
output_keep_prob=out_keep_prob)
3581+
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(),
3582+
input_keep_prob=in_keep_prob,
3583+
output_keep_prob=1.0) # out_keep_prob)
3584+
else:
3585+
cell_creator = rnn_creator
3586+
self.fw_cell = cell_creator()
3587+
self.bw_cell = cell_creator()
3588+
35873589
# Apply multiple layers
35883590
if n_layer > 1:
35893591
try: # TF1.0
@@ -3592,13 +3594,11 @@ def __init__(
35923594
MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
35933595

35943596
try:
3595-
self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer,
3596-
state_is_tuple=True)
3597-
self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer,
3598-
state_is_tuple=True)
3597+
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
3598+
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
35993599
except:
3600-
self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer)
3601-
self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer)
3600+
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
3601+
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
36023602

36033603
# Initial state of RNN
36043604
if fw_initial_state is None:
@@ -3938,7 +3938,7 @@ def __init__(
39383938

39393939
# Creates the cell function
39403940
# cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng
3941-
self.cell = cell_fn(num_units=n_hidden, **cell_init_args)
3941+
rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args)
39423942

39433943
# Apply dropout
39443944
if dropout:
@@ -3960,9 +3960,11 @@ def __init__(
39603960
# cell_instance_fn1(),
39613961
# input_keep_prob=in_keep_prob,
39623962
# output_keep_prob=out_keep_prob)
3963-
self.cell = DropoutWrapper_fn(self.cell,
3963+
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(),
39643964
input_keep_prob=in_keep_prob, output_keep_prob=1.0)#out_keep_prob)
3965-
3965+
else:
3966+
cell_creator = rnn_creator
3967+
self.cell = cell_creator()
39663968
# Apply multiple layers
39673969
if n_layer > 1:
39683970
try:
@@ -3973,10 +3975,10 @@ def __init__(
39733975
# cell_instance_fn2=cell_instance_fn # HanSheng
39743976
try:
39753977
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
3976-
self.cell = MultiRNNCell_fn([self.cell] * n_layer, state_is_tuple=True)
3978+
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)], state_is_tuple=True)
39773979
except: # when GRU
39783980
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
3979-
self.cell = MultiRNNCell_fn([self.cell] * n_layer)
3981+
self.cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
39803982

39813983
if dropout:
39823984
self.cell = DropoutWrapper_fn(self.cell,
@@ -4179,8 +4181,7 @@ def __init__(
41794181
with tf.variable_scope(name, initializer=initializer) as vs:
41804182
# Creates the cell function
41814183
# cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng
4182-
self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
4183-
self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
4184+
rnn_creator = lambda: cell_fn(num_units=n_hidden, **cell_init_args)
41844185

41854186
# Apply dropout
41864187
if dropout:
@@ -4202,15 +4203,13 @@ def __init__(
42024203
# cell_instance_fn1(),
42034204
# input_keep_prob=in_keep_prob,
42044205
# output_keep_prob=out_keep_prob)
4205-
4206-
self.fw_cell = DropoutWrapper_fn(
4207-
self.fw_cell,
4208-
input_keep_prob=in_keep_prob,
4209-
output_keep_prob=out_keep_prob)
4210-
self.bw_cell = DropoutWrapper_fn(
4211-
self.bw_cell,
4212-
input_keep_prob=in_keep_prob,
4213-
output_keep_prob=out_keep_prob)
4206+
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(),
4207+
input_keep_prob=in_keep_prob,
4208+
output_keep_prob=1.0) # out_keep_prob)
4209+
else:
4210+
cell_creator = rnn_creator
4211+
self.fw_cell = cell_creator()
4212+
self.bw_cell = cell_creator()
42144213
# Apply multiple layers
42154214
if n_layer > 1:
42164215
try:
@@ -4220,8 +4219,8 @@ def __init__(
42204219

42214220
# cell_instance_fn2=cell_instance_fn # HanSheng
42224221
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)])
4223-
self.fw_cell = MultiRNNCell_fn([self.fw_cell] * n_layer)
4224-
self.bw_cell = MultiRNNCell_fn([self.bw_cell] * n_layer)
4222+
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
4223+
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
42254224
# self.fw_cell=cell_instance_fn()
42264225
# self.bw_cell=cell_instance_fn()
42274226
# Initial state of RNN
@@ -5256,17 +5255,17 @@ def sampled_loss(inputs, labels):
52565255
# ============ Seq Encode Layer =============
52575256
# Create the internal multi-layer cell for our RNN.
52585257
try: # TF1.0
5259-
single_cell = tf.contrib.rnn.GRUCell(size)
5258+
cell_creator = lambda: tf.contrib.rnn.GRUCell(size)
52605259
except:
5261-
single_cell = tf.nn.rnn_cell.GRUCell(size)
5260+
cell_creator = lambda: tf.nn.rnn_cell.GRUCell(size)
52625261

52635262
if use_lstm:
52645263
try: # TF1.0
5265-
single_cell = tf.contrib.rnn.BasicLSTMCell(size)
5264+
cell_creator = lambda: tf.contrib.rnn.BasicLSTMCell(size)
52665265
except:
5267-
single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
5266+
cell_creator = lambda: tf.nn.rnn_cell.BasicLSTMCell(size)
52685267

5269-
cell = single_cell
5268+
cell = cell_creator()
52705269
if num_layers > 1:
52715270
try: # TF1.0
52725271
cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)

0 commit comments

Comments
 (0)