1515from six .moves import xrange
1616import random , warnings
1717import copy
18-
18+ import inspect
1919# __all__ = [
2020# "Layer",
2121# "DenseLayer",
@@ -3397,7 +3397,10 @@ def __init__(
33973397 # for input_ in tf.split(1, num_steps, inputs)]
33983398 # outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)
33993399 outputs = []
3400- self .cell = cell = cell_fn (num_units = n_hidden , ** cell_init_args )
3400+ if 'reuse' in inspect .getargspec (cell_fn .__init__ ).args :
3401+ self .cell = cell = cell_fn (num_units = n_hidden , reuse = tf .get_variable_scope ().reuse , ** cell_init_args )
3402+ else :
3403+ self .cell = cell = cell_fn (num_units = n_hidden , ** cell_init_args )
34013404 if initial_state is None :
34023405 self .initial_state = cell .zero_state (batch_size , dtype = tf .float32 ) # 1.2.3
34033406 state = self .initial_state
@@ -3560,8 +3563,7 @@ def __init__(
35603563 raise Exception ("RNN : Input dimension should be rank 3 : [batch_size, n_steps, n_features]" )
35613564
35623565 with tf .variable_scope (name , initializer = initializer ) as vs :
3563- self .fw_cell = cell_fn (num_units = n_hidden , ** cell_init_args )
3564- self .bw_cell = cell_fn (num_units = n_hidden , ** cell_init_args )
3566+ rnn_creator = lambda : cell_fn (num_units = n_hidden , ** cell_init_args )
35653567 # Apply dropout
35663568 if dropout :
35673569 if type (dropout ) in [tuple , list ]:
@@ -3576,14 +3578,14 @@ def __init__(
35763578 DropoutWrapper_fn = tf .contrib .rnn .DropoutWrapper
35773579 except :
35783580 DropoutWrapper_fn = tf .nn .rnn_cell .DropoutWrapper
3579- self . fw_cell = DropoutWrapper_fn (
3580- self . fw_cell ,
3581- input_keep_prob = in_keep_prob ,
3582- output_keep_prob = out_keep_prob )
3583- self . bw_cell = DropoutWrapper_fn (
3584- self .bw_cell ,
3585- input_keep_prob = in_keep_prob ,
3586- output_keep_prob = out_keep_prob )
3581+ cell_creator = lambda : DropoutWrapper_fn (rnn_creator (),
3582+ input_keep_prob = in_keep_prob ,
3583+ output_keep_prob = 1.0 ) # out_keep_prob)
3584+ else :
3585+ cell_creator = rnn_creator
3586+ self .fw_cell = cell_creator ()
3587+ self . bw_cell = cell_creator ()
3588+
35873589 # Apply multiple layers
35883590 if n_layer > 1 :
35893591 try : # TF1.0
@@ -3592,13 +3594,11 @@ def __init__(
35923594 MultiRNNCell_fn = tf .nn .rnn_cell .MultiRNNCell
35933595
35943596 try :
3595- self .fw_cell = MultiRNNCell_fn ([self .fw_cell ] * n_layer ,
3596- state_is_tuple = True )
3597- self .bw_cell = MultiRNNCell_fn ([self .bw_cell ] * n_layer ,
3598- state_is_tuple = True )
3597+ self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
3598+ self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
35993599 except :
3600- self .fw_cell = MultiRNNCell_fn ([self . fw_cell ] * n_layer )
3601- self .bw_cell = MultiRNNCell_fn ([self . bw_cell ] * n_layer )
3600+ self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range ( n_layer )] )
3601+ self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range ( n_layer )] )
36023602
36033603 # Initial state of RNN
36043604 if fw_initial_state is None :
@@ -3938,7 +3938,7 @@ def __init__(
39383938
39393939 # Creats the cell function
39403940 # cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng
3941- self . cell = cell_fn (num_units = n_hidden , ** cell_init_args )
3941+ rnn_creator = lambda : cell_fn (num_units = n_hidden , ** cell_init_args )
39423942
39433943 # Apply dropout
39443944 if dropout :
@@ -3960,9 +3960,11 @@ def __init__(
39603960 # cell_instance_fn1(),
39613961 # input_keep_prob=in_keep_prob,
39623962 # output_keep_prob=out_keep_prob)
3963- self . cell = DropoutWrapper_fn (self . cell ,
3963+ cell_creator = lambda : DropoutWrapper_fn (rnn_creator () ,
39643964 input_keep_prob = in_keep_prob , output_keep_prob = 1.0 )#out_keep_prob)
3965-
3965+ else :
3966+ cell_creator = rnn_creator
3967+ self .cell = cell_creator ()
39663968 # Apply multiple layers
39673969 if n_layer > 1 :
39683970 try :
@@ -3973,10 +3975,10 @@ def __init__(
39733975 # cell_instance_fn2=cell_instance_fn # HanSheng
39743976 try :
39753977 # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
3976- self .cell = MultiRNNCell_fn ([self . cell ] * n_layer , state_is_tuple = True )
3978+ self .cell = MultiRNNCell_fn ([cell_creator () for _ in range ( n_layer )] , state_is_tuple = True )
39773979 except : # when GRU
39783980 # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
3979- self .cell = MultiRNNCell_fn ([self . cell ] * n_layer )
3981+ self .cell = MultiRNNCell_fn ([cell_creator () for _ in range ( n_layer )] )
39803982
39813983 if dropout :
39823984 self .cell = DropoutWrapper_fn (self .cell ,
@@ -4179,8 +4181,7 @@ def __init__(
41794181 with tf .variable_scope (name , initializer = initializer ) as vs :
41804182 # Creats the cell function
41814183 # cell_instance_fn=lambda: cell_fn(num_units=n_hidden, **cell_init_args) # HanSheng
4182- self .fw_cell = cell_fn (num_units = n_hidden , ** cell_init_args )
4183- self .bw_cell = cell_fn (num_units = n_hidden , ** cell_init_args )
4184+ rnn_creator = lambda : cell_fn (num_units = n_hidden , ** cell_init_args )
41844185
41854186 # Apply dropout
41864187 if dropout :
@@ -4202,15 +4203,13 @@ def __init__(
42024203 # cell_instance_fn1(),
42034204 # input_keep_prob=in_keep_prob,
42044205 # output_keep_prob=out_keep_prob)
4205-
4206- self .fw_cell = DropoutWrapper_fn (
4207- self .fw_cell ,
4208- input_keep_prob = in_keep_prob ,
4209- output_keep_prob = out_keep_prob )
4210- self .bw_cell = DropoutWrapper_fn (
4211- self .bw_cell ,
4212- input_keep_prob = in_keep_prob ,
4213- output_keep_prob = out_keep_prob )
4206+ cell_creator = lambda : DropoutWrapper_fn (rnn_creator (),
4207+ input_keep_prob = in_keep_prob ,
4208+ output_keep_prob = 1.0 ) # out_keep_prob)
4209+ else :
4210+ cell_creator = rnn_creator
4211+ self .fw_cell = cell_creator ()
4212+ self .bw_cell = cell_creator ()
42144213 # Apply multiple layers
42154214 if n_layer > 1 :
42164215 try :
@@ -4220,8 +4219,8 @@ def __init__(
42204219
42214220 # cell_instance_fn2=cell_instance_fn # HanSheng
42224221 # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)])
4223- self .fw_cell = MultiRNNCell_fn ([self . fw_cell ] * n_layer )
4224- self .bw_cell = MultiRNNCell_fn ([self . bw_cell ] * n_layer )
4222+ self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range ( n_layer )] )
4223+ self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range ( n_layer )] )
42254224 # self.fw_cell=cell_instance_fn()
42264225 # self.bw_cell=cell_instance_fn()
42274226 # Initial state of RNN
@@ -5256,17 +5255,17 @@ def sampled_loss(inputs, labels):
52565255 # ============ Seq Encode Layer =============
52575256 # Create the internal multi-layer cell for our RNN.
52585257 try : # TF1.0
5259- single_cell = tf .contrib .rnn .GRUCell (size )
5258+ cell_creator = lambda : tf .contrib .rnn .GRUCell (size )
52605259 except :
5261- single_cell = tf .nn .rnn_cell .GRUCell (size )
5260+ cell_creator = lambda : tf .nn .rnn_cell .GRUCell (size )
52625261
52635262 if use_lstm :
52645263 try : # TF1.0
5265- single_cell = tf .contrib .rnn .BasicLSTMCell (size )
5264+ cell_creator = lambda : tf .contrib .rnn .BasicLSTMCell (size )
52665265 except :
5267- single_cell = tf .nn .rnn_cell .BasicLSTMCell (size )
5266+ cell_creator = lambda : tf .nn .rnn_cell .BasicLSTMCell (size )
52685267
5269- cell = single_cell
5268+ cell = cell_creator ()
52705269 if num_layers > 1 :
52715270 try : # TF1.0
52725271 cell = tf .contrib .rnn .MultiRNNCell ([single_cell ] * num_layers )
0 commit comments