@@ -379,13 +379,20 @@ def __init__(
379379 MultiRNNCell_fn = tf .contrib .rnn .MultiRNNCell
380380 except Exception :
381381 MultiRNNCell_fn = tf .nn .rnn_cell .MultiRNNCell
382-
383- try :
384- self .fw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
385- self .bw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
386- except Exception :
387- self .fw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
388- self .bw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
382+ if dropout :
383+ try :
384+ self .fw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
385+ self .bw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
386+ except Exception :
387+ self .fw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
388+ self .bw_cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
389+ else :
390+ try :
391+ self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
392+ self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
393+ except Exception :
394+ self .fw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
395+ self .bw_cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
389396
390397 # Initial state of RNN
391398 if fw_initial_state is None :
@@ -1081,12 +1088,18 @@ def __init__(
10811088 MultiRNNCell_fn = tf .nn .rnn_cell .MultiRNNCell
10821089
10831090 # cell_instance_fn2=cell_instance_fn # HanSheng
1084- try :
1085- # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
1086- self .cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
1087- except Exception : # when GRU
1088- # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
1089- self .cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
1091+ if dropout :
1092+ try :
1093+ # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)], state_is_tuple=True) # HanSheng
1094+ self .cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )], state_is_tuple = True )
1095+ except Exception : # when GRU
1096+ # cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)]) # HanSheng
1097+ self .cell = MultiRNNCell_fn ([cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )])
1098+ else :
1099+ try :
1100+ self .cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )], state_is_tuple = True )
1101+ except Exception : # when GRU
1102+ self .cell = MultiRNNCell_fn ([cell_creator () for _ in range (n_layer )])
10901103
10911104 # self.cell=cell_instance_fn() # HanSheng
10921105
@@ -1338,8 +1351,12 @@ def __init__(
13381351 sequence_length = retrieve_seq_length_op (self .inputs if isinstance (self .inputs , tf .Tensor ) else tf .pack (self .inputs ))
13391352
13401353 if n_layer > 1 :
1341- self .fw_cell = [cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )]
1342- self .bw_cell = [cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )]
1354+ if dropout :
1355+ self .fw_cell = [cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )]
1356+ self .bw_cell = [cell_creator (is_last = i == n_layer - 1 ) for i in range (n_layer )]
1357+ else :
1358+ self .fw_cell = [cell_creator () for _ in range (n_layer )]
1359+ self .bw_cell = [cell_creator () for _ in range (n_layer )]
13431360 from tensorflow .contrib .rnn import stack_bidirectional_dynamic_rnn
13441361 outputs , states_fw , states_bw = stack_bidirectional_dynamic_rnn (
13451362 cells_fw = self .fw_cell ,
0 commit comments