@@ -378,7 +378,10 @@ def __init__(
378378 DropoutWrapper_fn = tf .contrib .rnn .DropoutWrapper
379379 except Exception :
380380 DropoutWrapper_fn = tf .nn .rnn_cell .DropoutWrapper
381- cell_creator = lambda : DropoutWrapper_fn (rnn_creator (), input_keep_prob = in_keep_prob , output_keep_prob = 1.0 ) # out_keep_prob)
381+ cell_creator = lambda is_last = True : \
382+ DropoutWrapper_fn (rnn_creator (),
383+ input_keep_prob = in_keep_prob ,
384+ output_keep_prob = out_keep_prob if is_last else 1.0 )
382385 else :
383386 cell_creator = rnn_creator
384387 self .fw_cell = cell_creator ()
@@ -392,11 +395,11 @@ def __init__(
392395 MultiRNNCell_fn = tf .nn .rnn_cell .MultiRNNCell
393396
394397 try :
395- self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
396- self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
398+ self .fw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
399+ self .bw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
397400 except Exception :
398- self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
399- self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
401+ self .fw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
402+ self .bw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
400403
401404 # Initial state of RNN
402405 if fw_initial_state is None :
@@ -1076,7 +1079,10 @@ def __init__(
10761079 # cell_instance_fn1(),
10771080 # input_keep_prob=in_keep_prob,
10781081 # output_keep_prob=out_keep_prob)
1079- cell_creator = lambda : DropoutWrapper_fn (rnn_creator (), input_keep_prob = in_keep_prob , output_keep_prob = 1.0 )
1082+ cell_creator = lambda is_last = True : \
1083+ DropoutWrapper_fn (rnn_creator (),
1084+ input_keep_prob = in_keep_prob ,
1085+ output_keep_prob = out_keep_prob if is_last else 1.0 )
10801086 else :
10811087 cell_creator = rnn_creator
10821088 self .cell = cell_creator ()
@@ -1090,10 +1096,10 @@ def __init__(
10901096 # cell_instance_fn2=cell_instance_fn # HanSheng
10911097 try :
10921098 # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
1093- self .cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
1099+ self .cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
10941100 except Exception : # when GRU
10951101 # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
1096- self .cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
1102+ self .cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
10971103
10981104 # self.cell=cell_instance_fn() # HanSheng
10991105
0 commit comments