@@ -5623,53 +5623,58 @@ def __init__(
56235623 # cell_instance_fn1(),
56245624 # input_keep_prob=in_keep_prob,
56255625 # output_keep_prob=out_keep_prob)
5626- cell_creator = lambda : DropoutWrapper_fn (rnn_creator (), input_keep_prob = in_keep_prob , output_keep_prob = 1.0 ) # out_keep_prob)
5626+ cell_creator = lambda is_last = True : \
5627+ DropoutWrapper_fn (rnn_creator (),
5628+ input_keep_prob = in_keep_prob ,
5629+ output_keep_prob = out_keep_prob if is_last else 1.0 ) # out_keep_prob)
56275630 else :
5628- cell_creator = rnn_creator
5629- self .fw_cell = cell_creator ()
5630- self .bw_cell = cell_creator ()
5631- # Apply multiple layers
5632- if n_layer > 1 :
5633- try :
5634- MultiRNNCell_fn = tf .contrib .rnn .MultiRNNCell
5635- except :
5636- MultiRNNCell_fn = tf .nn .rnn_cell .MultiRNNCell
5631+ cell_creator = lambda : rnn_creator ()
56375632
5638- # cell_instance_fn2=cell_instance_fn # HanSheng
5639- # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)])
5640- self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
5641- self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
56425633
5643- if dropout :
5644- self .fw_cell = DropoutWrapper_fn (self .fw_cell , input_keep_prob = 1.0 , output_keep_prob = out_keep_prob )
5645- self .bw_cell = DropoutWrapper_fn (self .bw_cell , input_keep_prob = 1.0 , output_keep_prob = out_keep_prob )
5634+ # if dropout:
5635+ # self.fw_cell = DropoutWrapper_fn(self.fw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
5636+ # self.bw_cell = DropoutWrapper_fn(self.bw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
56465637
56475638 # self.fw_cell=cell_instance_fn()
56485639 # self.bw_cell=cell_instance_fn()
56495640 # Initial state of RNN
5650- if fw_initial_state is None :
5651- self .fw_initial_state = self .fw_cell .zero_state (self .batch_size , dtype = D_TYPE ) # dtype=tf.float32)
5652- else :
5653- self .fw_initial_state = fw_initial_state
5654- if bw_initial_state is None :
5655- self .bw_initial_state = self .bw_cell .zero_state (self .batch_size , dtype = D_TYPE ) # dtype=tf.float32)
5656- else :
5657- self .bw_initial_state = bw_initial_state
5641+
5642+ self .fw_initial_state = fw_initial_state
5643+ self .bw_initial_state = bw_initial_state
56585644 # Computes sequence_length
56595645 if sequence_length is None :
56605646 try : ## TF1.0
56615647 sequence_length = retrieve_seq_length_op (self .inputs if isinstance (self .inputs , tf .Tensor ) else tf .stack (self .inputs ))
56625648 except : ## TF0.12
56635649 sequence_length = retrieve_seq_length_op (self .inputs if isinstance (self .inputs , tf .Tensor ) else tf .pack (self .inputs ))
56645650
5665- outputs , (states_fw , states_bw ) = tf .nn .bidirectional_dynamic_rnn (
5666- cell_fw = self .fw_cell ,
5667- cell_bw = self .bw_cell ,
5668- inputs = self .inputs ,
5669- sequence_length = sequence_length ,
5670- initial_state_fw = self .fw_initial_state ,
5671- initial_state_bw = self .bw_initial_state ,
5672- ** dynamic_rnn_init_args )
5651+ if n_layer > 1 :
5652+ self .fw_cell = [cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )]
5653+ self .bw_cell = [cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )]
5654+ from tensorflow .contrib .rnn import stack_bidirectional_dynamic_rnn
5655+ outputs , states_fw , states_bw = stack_bidirectional_dynamic_rnn (
5656+ cells_fw = self .fw_cell ,
5657+ cells_bw = self .bw_cell ,
5658+ inputs = self .inputs ,
5659+ sequence_length = sequence_length ,
5660+ initial_states_fw = self .fw_initial_state ,
5661+ initial_states_bw = self .bw_initial_state ,
5662+ dtype = D_TYPE ,
5663+ ** dynamic_rnn_init_args )
5664+
5665+ else :
5666+ self .fw_cell = cell_creator ()
5667+ self .bw_cell = cell_creator ()
5668+ outputs , (states_fw , states_bw ) = tf .nn .bidirectional_dynamic_rnn (
5669+ cell_fw = self .fw_cell ,
5670+ cell_bw = self .bw_cell ,
5671+ inputs = self .inputs ,
5672+ sequence_length = sequence_length ,
5673+ initial_state_fw = self .fw_initial_state ,
5674+ initial_state_bw = self .bw_initial_state ,
5675+ dtype = D_TYPE ,
5676+ ** dynamic_rnn_init_args )
5677+
56735678 rnn_variables = tf .get_collection (TF_GRAPHKEYS_VARIABLES , scope = vs .name )
56745679
56755680 print (" n_params : %d" % (len (rnn_variables )))
0 commit comments