Commit b27f55f

[layers] update BiDynamicRNNLayer to be available, compatible with the API of BiRNNLayer and DynamicRNNLayer
1 parent: 9f602dd


tensorlayer/layers.py

Lines changed: 124 additions & 50 deletions
@@ -2702,6 +2702,7 @@ def __init__(
         self.all_layers.extend( [self.outputs] )
         self.all_params.extend( rnn_variables )
 
+
 # Bidirectional Dynamic RNN
 class BiDynamicRNNLayer(Layer):
     """
@@ -2722,16 +2723,30 @@ class BiDynamicRNNLayer(Layer):
         The arguments for the cell initializer.
     n_hidden : an int
         The number of hidden units in the layer.
-    n_steps : an int
-        The sequence length.
+    initializer : initializer
+        The initializer for initializing the parameters.
+    sequence_length : a tensor, array or None
+        The sequence length of each row of input data, see ``Advanced Ops for Dynamic RNN``.
+            - If None, it uses ``retrieve_seq_length_op`` to compute the sequence_length, i.e. when the features of padding (on the right hand side) are all zeros.
+            - If using word embedding, you may need to compute the sequence_length from the ID array (the integer features before word embedding) by using ``retrieve_seq_length_op2`` or ``retrieve_seq_length_op``.
+            - You can also input a numpy array.
+            - More details about TensorFlow dynamic_rnn in `Wild-ML Blog <http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/>`_.
+    fw_initial_state : None or forward RNN State
+        If None, initial_state is zero_state.
+    bw_initial_state : None or backward RNN State
+        If None, initial_state is zero_state.
+    dropout : `tuple` of `float`: (input_keep_prob, output_keep_prob).
+        The input and output keep probability.
+    n_layer : an int, default is 1.
+        The number of RNN layers.
     return_last : boolean
         If True, return the last output, "Sequence input and single output"\n
         If False, return all outputs, "Synced sequence input and output"\n
         In other words, if you want to apply one or more RNN(s) on this layer, set to False.
     return_seq_2d : boolean
-        When return_last = False\n
-        if True, return 2D Tensor [n_example, n_hidden], for stacking DenseLayer after it.
-        if False, return 3D Tensor [n_example/n_steps, n_steps, n_hidden], for stacking multiple RNN after it.
+        - When return_last = False
+            - If True, return 2D Tensor [n_example, 2 * n_hidden], for stacking DenseLayer or computing cost after it.
+            - If False, return 3D Tensor [n_example/n_steps(max), n_steps(max), 2 * n_hidden], for stacking multiple RNN after it.
     name : a string or None
         An optional name to attach to this layer.
 
@@ -2740,20 +2755,23 @@ class BiDynamicRNNLayer(Layer):
     outputs : a tensor
         The output of this RNN.
         return_last = False, outputs = all cell_output, which is the hidden state.
-            cell_output.get_shape() = (?, n_hidden)
+            cell_output.get_shape() = (?, 2 * n_hidden)
 
-    final_state : a tensor or StateTuple
+    fw(bw)_final_state : a tensor or StateTuple
         When state_is_tuple = False,
         it is the final hidden and cell states, states.get_shape() = [?, 2 * n_hidden].\n
         When state_is_tuple = True, it stores two elements: (c, h), in that order.
         You can get the final state after each iteration during training, then
         feed it to the initial state of the next iteration.
 
-    initial_state : a tensor or StateTuple
+    fw(bw)_initial_state : a tensor or StateTuple
         It is the initial state of this RNN layer, you can use it to initialize
         your state at the beginning of each epoch or iteration according to your
         training procedure.
 
+    sequence_length : a tensor or array, shape = [batch_size]
+        The sequence lengths computed by Advanced Ops or the given sequence lengths.
+
     Notes
     -----
     Input dimension should be rank 3 : [batch_size, n_steps(max), n_features], if not, please see :class:`ReshapeLayer`.
@@ -2768,59 +2786,118 @@ def __init__(
         self,
         layer = None,
         cell_fn = tf.nn.rnn_cell.LSTMCell,
-        cell_init_args = {'state_is_tuple' : True},
-        n_hidden = 64,
+        cell_init_args = {},
+        n_hidden = 100,
         initializer = tf.random_uniform_initializer(-0.1, 0.1),
-        # n_steps = 5,
+        sequence_length = None,
+        fw_initial_state = None,
+        bw_initial_state = None,
+        dropout = None,
+        n_layer = 1,
         return_last = False,
-        # is_reshape = True,
         return_seq_2d = False,
-        name = 'birnn_layer',
+        name = 'bi_dyrnn_layer',
     ):
         Layer.__init__(self, name=name)
         self.inputs = layer.outputs
 
-        print(" tensorlayer:Instantiate BiDynamicRNNLayer %s: n_hidden:%d, n_steps:%d, in_dim:%d %s, cell_fn:%s " % (self.name, n_hidden,
-              n_steps, self.inputs.get_shape().ndims, self.inputs.get_shape(), cell_fn.__name__))
-        print(" Untested !!!")
+        print(" tensorlayer:Instantiate BiDynamicRNNLayer %s: n_hidden:%d, in_dim:%d %s, cell_fn:%s, dropout:%s, n_layer:%d" %
+              (self.name, n_hidden, self.inputs.get_shape().ndims, self.inputs.get_shape(), cell_fn.__name__, dropout, n_layer))
 
-        self.cell = cell = cell_fn(num_units=n_hidden, **cell_init_args)
-        # self.initial_state = cell.zero_state(batch_size, dtype=tf.float32)
-        # state = self.initial_state
+        # Input dimension should be rank 3 [batch_size, n_steps(max), n_features]
+        try:
+            self.inputs.get_shape().with_rank(3)
+        except:
+            raise Exception("RNN : Input dimension should be rank 3 : [batch_size, n_steps(max), n_features]")
+
+        # Get the batch_size
+        fixed_batch_size = self.inputs.get_shape().with_rank_at_least(1)[0]
+        if fixed_batch_size.value:
+            batch_size = fixed_batch_size.value
+            print(" batch_size (concurrent processes): %d" % batch_size)
+        else:
+            from tensorflow.python.ops import array_ops
+            batch_size = array_ops.shape(self.inputs)[0]
+            print(" unspecified batch_size, uses a tensor instead.")
+        self.batch_size = batch_size
 
         with tf.variable_scope(name, initializer=initializer) as vs:
-            outputs, states = tf.nn.bidirectional_dynamic_rnn(
-                cell_fw=cell,
-                cell_bw=cell,
-                dtype=tf.float64,
-                sequence_length=X_lengths,
-                inputs=X)
-
-            output_fw, output_bw = outputs
-            states_fw, states_bw = states
-
-            result = tf.contrib.learn.run_n(
-                {"output_fw": output_fw, "output_bw": output_bw, "states_fw": states_fw, "states_bw": states_bw},
-                n=1,
-                feed_dict=None)
-            rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+            # Create the forward and backward cells
+            self.fw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
+            self.bw_cell = cell_fn(num_units=n_hidden, **cell_init_args)
 
-            print(" n_params : %d" % (len(rnn_variables)))
+            # Apply dropout
+            if dropout:
+                if type(dropout) in [tuple, list]:
+                    in_keep_prob = dropout[0]
+                    out_keep_prob = dropout[1]
+                elif isinstance(dropout, float):
+                    in_keep_prob, out_keep_prob = dropout, dropout
+                else:
+                    raise Exception("Invalid dropout type (must be a 2-element tuple of "
+                                    "float)")
+                self.fw_cell = tf.nn.rnn_cell.DropoutWrapper(
+                    self.fw_cell,
+                    input_keep_prob=in_keep_prob,
+                    output_keep_prob=out_keep_prob)
+                self.bw_cell = tf.nn.rnn_cell.DropoutWrapper(
+                    self.bw_cell,
+                    input_keep_prob=in_keep_prob,
+                    output_keep_prob=out_keep_prob)
+            # Apply multiple layers
+            if n_layer > 1:
+                print(" n_layer: %d" % n_layer)
+                self.fw_cell = tf.nn.rnn_cell.MultiRNNCell([self.fw_cell] * n_layer)
+                self.bw_cell = tf.nn.rnn_cell.MultiRNNCell([self.bw_cell] * n_layer)
+            # Initial state of RNN
+            if fw_initial_state is None:
+                self.fw_initial_state = self.fw_cell.zero_state(self.batch_size, dtype=tf.float32)
+            else:
+                self.fw_initial_state = fw_initial_state
+            if bw_initial_state is None:
+                self.bw_initial_state = self.bw_cell.zero_state(self.batch_size, dtype=tf.float32)
+            else:
+                self.bw_initial_state = bw_initial_state
+            # Compute sequence_length if it is not given
+            if sequence_length is None:
+                sequence_length = retrieve_seq_length_op(
+                    self.inputs if isinstance(self.inputs, tf.Tensor) else tf.pack(self.inputs))
 
-        if return_last:
-            # 2D Tensor [batch_size, n_hidden]
-            self.outputs = output_fw
-        else:
-            if return_seq_2d:
-                # PTB tutorial:
-                # 2D Tensor [n_example, n_hidden]
-                self.outputs = tf.reshape(tf.concat(1, output_fw), [-1, n_hidden])
+            outputs, (states_fw, states_bw) = tf.nn.bidirectional_dynamic_rnn(
+                cell_fw=self.fw_cell,
+                cell_bw=self.bw_cell,
+                inputs=self.inputs,
+                sequence_length=sequence_length,
+                initial_state_fw=self.fw_initial_state,
+                initial_state_bw=self.bw_initial_state,
+            )
+            rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+
+            print(" n_params : %d" % (len(rnn_variables)))
+            # Manage the outputs
+            outputs = tf.concat(-1, outputs)
+            if return_last:
+                # [batch_size, 2 * n_hidden]
+                self.outputs = advanced_indexing_op(outputs, sequence_length)
             else:
-                # <akara>:
-                # 3D Tensor [n_example/n_steps, n_steps, n_hidden]
-                self.outputs = tf.reshape(tf.concat(1, output_fw), [-1, n_steps, n_hidden])
+                # [batch_size, n_steps(max), 2 * n_hidden]
+                if return_seq_2d:
+                    # PTB tutorial:
+                    # 2D Tensor [n_example, 2 * n_hidden]
+                    self.outputs = tf.reshape(tf.concat(1, outputs), [-1, 2 * n_hidden])
+                else:
+                    # <akara>:
+                    # 3D Tensor [batch_size, n_steps(max), 2 * n_hidden]
+                    max_length = tf.shape(outputs)[1]
+                    batch_size = tf.shape(outputs)[0]
+                    self.outputs = tf.reshape(tf.concat(1, outputs), [batch_size, max_length, 2 * n_hidden])
+                    # self.outputs = tf.reshape(tf.concat(1, outputs), [-1, max_length, 2 * n_hidden])
 
-        self.final_state = state
+            # Final states
+            self.fw_final_states = states_fw
+            self.bw_final_states = states_bw
+
+            self.sequence_length = sequence_length
 
         self.all_layers = list(layer.all_layers)
         self.all_params = list(layer.all_params)
@@ -2830,9 +2907,6 @@ def __init__(
         self.all_params.extend( rnn_variables )
 
 
-
-
-
 ## Shape layer
 class FlattenLayer(Layer):
     """

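Because zero padding disappears after word embedding, the ``sequence_length`` docstring above recommends computing lengths from the integer ID array before embedding. A sketch, assuming ``EmbeddingInputlayer`` and ``retrieve_seq_length_op2`` from the same module; the vocabulary and sizes are made up:

    # Word IDs, right-padded with 0: [batch_size, n_steps(max)].
    x_id = tf.placeholder(tf.int32, shape=[32, 20])

    # Count non-zero IDs per row *before* the embedding lookup.
    seq_len = tl.layers.retrieve_seq_length_op2(x_id)

    network = tl.layers.EmbeddingInputlayer(inputs=x_id,
                vocabulary_size=10000,
                embedding_size=128,
                name='embedding_layer')
    network = tl.layers.BiDynamicRNNLayer(network,
                n_hidden=100,
                sequence_length=seq_len,  # explicit lengths from the ID array
                return_seq_2d=True,
                name='bi_dyrnn_layer')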
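When ``return_last=True``, the new code uses ``advanced_indexing_op`` to take each sequence's output at its true last step rather than the padded last step, yielding a [batch_size, 2 * n_hidden] tensor. A sequence-classification sketch continuing the assumptions above:

    # One prediction per (padded) sequence.
    emb = tl.layers.EmbeddingInputlayer(inputs=x_id,
                vocabulary_size=10000,
                embedding_size=128,
                name='embedding_layer2')
    rnn = tl.layers.BiDynamicRNNLayer(emb,
                n_hidden=100,
                sequence_length=seq_len,
                return_last=True,       # [batch_size, 2 * n_hidden]
                name='bi_dyrnn_last')
    network = tl.layers.DenseLayer(rnn, n_units=2, act=tf.identity,
                name='classifier')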