Skip to content

Commit 9e837da

Browse files
authored
Merge pull request #295 from matthew-z/stack_bidynamic
Use stack_bidirectional_dynamic_rnn for multi-layer BiDynamicRNN
2 parents: c361052 + 36ff5f9; commit: 9e837da

File tree

1 file changed

+38
-33
lines changed

1 file changed

+38
-33
lines changed

tensorlayer/layers.py

Lines changed: 38 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -5623,53 +5623,58 @@ def __init__(
56235623
# cell_instance_fn1(),
56245624
# input_keep_prob=in_keep_prob,
56255625
# output_keep_prob=out_keep_prob)
5626-
cell_creator = lambda: DropoutWrapper_fn(rnn_creator(), input_keep_prob=in_keep_prob, output_keep_prob=1.0) # out_keep_prob)
5626+
cell_creator = lambda is_last=True: \
5627+
DropoutWrapper_fn(rnn_creator(),
5628+
input_keep_prob=in_keep_prob,
5629+
output_keep_prob=out_keep_prob if is_last else 1.0) # out_keep_prob)
56275630
else:
5628-
cell_creator = rnn_creator
5629-
self.fw_cell = cell_creator()
5630-
self.bw_cell = cell_creator()
5631-
# Apply multiple layers
5632-
if n_layer > 1:
5633-
try:
5634-
MultiRNNCell_fn = tf.contrib.rnn.MultiRNNCell
5635-
except:
5636-
MultiRNNCell_fn = tf.nn.rnn_cell.MultiRNNCell
5631+
cell_creator = lambda : rnn_creator()
56375632

5638-
# cell_instance_fn2=cell_instance_fn # HanSheng
5639-
# cell_instance_fn=lambda: MultiRNNCell_fn([cell_instance_fn2() for _ in range(n_layer)])
5640-
self.fw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
5641-
self.bw_cell = MultiRNNCell_fn([cell_creator() for _ in range(n_layer)])
56425633

5643-
if dropout:
5644-
self.fw_cell = DropoutWrapper_fn(self.fw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
5645-
self.bw_cell = DropoutWrapper_fn(self.bw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
5634+
# if dropout:
5635+
# self.fw_cell = DropoutWrapper_fn(self.fw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
5636+
# self.bw_cell = DropoutWrapper_fn(self.bw_cell, input_keep_prob=1.0, output_keep_prob=out_keep_prob)
56465637

56475638
# self.fw_cell=cell_instance_fn()
56485639
# self.bw_cell=cell_instance_fn()
56495640
# Initial state of RNN
5650-
if fw_initial_state is None:
5651-
self.fw_initial_state = self.fw_cell.zero_state(self.batch_size, dtype=D_TYPE) # dtype=tf.float32)
5652-
else:
5653-
self.fw_initial_state = fw_initial_state
5654-
if bw_initial_state is None:
5655-
self.bw_initial_state = self.bw_cell.zero_state(self.batch_size, dtype=D_TYPE) # dtype=tf.float32)
5656-
else:
5657-
self.bw_initial_state = bw_initial_state
5641+
5642+
self.fw_initial_state = fw_initial_state
5643+
self.bw_initial_state = bw_initial_state
56585644
# Computes sequence_length
56595645
if sequence_length is None:
56605646
try: ## TF1.0
56615647
sequence_length = retrieve_seq_length_op(self.inputs if isinstance(self.inputs, tf.Tensor) else tf.stack(self.inputs))
56625648
except: ## TF0.12
56635649
sequence_length = retrieve_seq_length_op(self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))
56645650

5665-
outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
5666-
cell_fw=self.fw_cell,
5667-
cell_bw=self.bw_cell,
5668-
inputs=self.inputs,
5669-
sequence_length=sequence_length,
5670-
initial_state_fw=self.fw_initial_state,
5671-
initial_state_bw=self.bw_initial_state,
5672-
**dynamic_rnn_init_args)
5651+
if n_layer > 1:
5652+
self.fw_cell = [cell_creator(is_last= i == n_layer - 1) for i in range(n_layer)]
5653+
self.bw_cell = [cell_creator(is_last= i == n_layer - 1) for i in range(n_layer)]
5654+
from tensorflow.contrib.rnn import stack_bidirectional_dynamic_rnn
5655+
outputs, states_fw, states_bw = stack_bidirectional_dynamic_rnn(
5656+
cells_fw=self.fw_cell,
5657+
cells_bw=self.bw_cell,
5658+
inputs=self.inputs,
5659+
sequence_length=sequence_length,
5660+
initial_states_fw=self.fw_initial_state,
5661+
initial_states_bw=self.bw_initial_state,
5662+
dtype=D_TYPE,
5663+
**dynamic_rnn_init_args)
5664+
5665+
else:
5666+
self.fw_cell = cell_creator()
5667+
self.bw_cell = cell_creator()
5668+
outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
5669+
cell_fw=self.fw_cell,
5670+
cell_bw=self.bw_cell,
5671+
inputs=self.inputs,
5672+
sequence_length=sequence_length,
5673+
initial_state_fw=self.fw_initial_state,
5674+
initial_state_bw=self.bw_initial_state,
5675+
dtype=D_TYPE,
5676+
**dynamic_rnn_init_args)
5677+
56735678
rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
56745679

56755680
print(" n_params : %d" % (len(rnn_variables)))

0 commit comments

Comments
 (0)