1616from __future__ import division , print_function , absolute_import
1717
1818import tensorflow as tf
19- from tensorflow .models .rnn import rnn , rnn_cell
2019
2120from skflow .ops import mean_squared_error_regressor , softmax_classifier , dnn
2221
@@ -93,9 +92,102 @@ def dnn_estimator(X, y):
9392 return target_predictor_fn (layers , y )
9493 return dnn_estimator
9594
## This will be in TensorFlow 0.7.
## TODO(ilblackdragon): Clean this up when it's released.
97+
98+
99+ def _reverse_seq (input_seq , lengths ):
100+ """Reverse a list of Tensors up to specified lengths.
101+ Args:
102+ input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
103+ lengths: A tensor of dimension batch_size, containing lengths for each
104+ sequence in the batch. If "None" is specified, simply reverses
105+ the list.
106+ Returns:
107+ time-reversed sequence
108+ """
109+ if lengths is None :
110+ return list (reversed (input_seq ))
111+
112+ for input_ in input_seq :
113+ input_ .set_shape (input_ .get_shape ().with_rank (2 ))
114+
115+ # Join into (time, batch_size, depth)
116+ s_joined = tf .pack (input_seq )
117+
118+ # Reverse along dimension 0
119+ s_reversed = tf .reverse_sequence (s_joined , lengths , 0 , 1 )
120+ # Split again into list
121+ result = tf .unpack (s_reversed )
122+ return result
123+
124+
def bidirectional_rnn(cell_fw, cell_bw, inputs,
                      initial_state_fw=None, initial_state_bw=None,
                      dtype=None, sequence_length=None, scope=None):
    """Creates a bidirectional recurrent neural network.

    Builds independent forward and backward RNNs over `inputs` and
    depth-concatenates their per-timestep outputs, so the result has the
    format [time][batch][cell_fw.output_size + cell_bw.output_size]. The
    input_size of the forward and backward cell must match. Initial states
    default to zero (but may be provided), no intermediate states are ever
    returned, and the network is fully unrolled for the given (passed-in)
    length(s) of the sequence(s), or completely unrolled if no length is
    given.

    Args:
        cell_fw: An instance of RNNCell, to be used for forward direction.
        cell_bw: An instance of RNNCell, to be used for backward direction.
        inputs: A length T list of inputs, each a tensor of shape
            [batch_size, cell.input_size].
        initial_state_fw: (optional) An initial state for the forward RNN.
            This must be a tensor of appropriate type and shape
            [batch_size x cell.state_size].
        initial_state_bw: (optional) Same as for initial_state_fw.
        dtype: (optional) The data type for the initial state. Required if
            either of the initial states are not provided.
        sequence_length: (optional) An int64 vector (tensor) of size
            [batch_size], containing the actual lengths for each of the
            sequences.
        scope: VariableScope for the created subgraph; defaults to "BiRNN".

    Returns:
        A length T list of output `Tensors` (one per input step), each the
        depth-concatenation of the forward and backward outputs.

    Raises:
        TypeError: If "cell_fw" or "cell_bw" is not an instance of RNNCell.
        ValueError: If inputs is None or an empty list.
    """
    # Validate both cells in one loop; the message matches the per-arg form.
    for cell, arg_name in ((cell_fw, "cell_fw"), (cell_bw, "cell_bw")):
        if not isinstance(cell, tf.nn.rnn_cell.RNNCell):
            raise TypeError("%s must be an instance of RNNCell" % arg_name)
    if not isinstance(inputs, list):
        raise TypeError("inputs must be a list")
    if not inputs:
        raise ValueError("inputs must not be empty")

    scope_name = scope or "BiRNN"

    # Forward direction: unroll over the sequence as given.
    with tf.variable_scope(scope_name + "_FW"):
        output_fw, _ = tf.nn.rnn(cell_fw, inputs, initial_state_fw, dtype,
                                 sequence_length)

    # Backward direction: run over the time-reversed inputs, then flip the
    # outputs back so each step lines up with the forward direction.
    with tf.variable_scope(scope_name + "_BW"):
        reversed_out, _ = tf.nn.rnn(cell_bw,
                                    _reverse_seq(inputs, sequence_length),
                                    initial_state_bw, dtype,
                                    sequence_length)
        output_bw = _reverse_seq(reversed_out, sequence_length)

    # Depth-concatenate the forward/backward outputs at every timestep.
    return [tf.concat(1, [fw, bw]) for fw, bw in zip(output_fw, output_bw)]
185+
# End of TensorFlow 0.7 backport.
187+
96188
97189def get_rnn_model (rnn_size , cell_type , num_layers , input_op_fn ,
98- bidirection , target_predictor_fn ,
190+ bidirectional , target_predictor_fn ,
99191 sequence_length , initial_state ):
100192 """Returns a function that creates a RNN TensorFlow subgraph with given
101193 params.
@@ -107,13 +199,14 @@ def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn,
107199 input_op_fn: Function that will transform the input tensor, such as
108200 creating word embeddings, byte list, etc. This takes
109201 an argument X for input and returns transformed X.
110- bidirection: Whether this is a bidirectional rnn.
202+ bidirectional: boolean, Whether this is a bidirectional rnn.
111203 target_predictor_fn: Function that will predict target from input
112204 features. This can be logistic regression,
113205 linear regression or any other model,
114206 that takes X, y and returns predictions and loss tensors.
115207 sequence_length: If sequence_length is provided, dynamic calculation is performed.
116208 This saves computational time when unrolling past max sequence length.
209+ Required for bidirectional RNNs.
117210 initial_state: An initial state for the RNN. This must be a tensor of appropriate type
118211 and shape [batch_size x cell.state_size].
119212
@@ -124,26 +217,28 @@ def rnn_estimator(X, y):
124217 """RNN estimator with target predictor function on top."""
125218 X = input_op_fn (X )
126219 if cell_type == 'rnn' :
127- cell_fn = rnn_cell .BasicRNNCell
220+ cell_fn = tf . nn . rnn_cell .BasicRNNCell
128221 elif cell_type == 'gru' :
129- cell_fn = rnn_cell .GRUCell
222+ cell_fn = tf . nn . rnn_cell .GRUCell
130223 elif cell_type == 'lstm' :
131- cell_fn = rnn_cell .BasicLSTMCell
224+ cell_fn = tf . nn . rnn_cell .BasicLSTMCell
132225 else :
133226 raise ValueError ("cell_type {} is not supported. " .format (cell_type ))
134- if bidirection :
227+ if bidirectional :
135228 # forward direction cell
136- rnn_fw_cell = rnn_cell .MultiRNNCell ([cell_fn (rnn_size )] * num_layers )
229+ rnn_fw_cell = tf . nn . rnn_cell .MultiRNNCell ([cell_fn (rnn_size )] * num_layers )
137230 # backward direction cell
138- rnn_bw_cell = rnn_cell .MultiRNNCell ([cell_fn (rnn_size )] * num_layers )
231+ rnn_bw_cell = tf . nn . rnn_cell .MultiRNNCell ([cell_fn (rnn_size )] * num_layers )
139232 # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
140- encoding = rnn .bidirectional_rnn (rnn_fw_cell , rnn_bw_cell ,
141- sequence_length = sequence_length ,
142- initial_state = initial_state )
233+ encoding = bidirectional_rnn (rnn_fw_cell , rnn_bw_cell , X ,
234+ dtype = tf .float32 ,
235+ sequence_length = sequence_length ,
236+ initial_state_fw = initial_state ,
237+ initial_state_bw = initial_state )
143238 else :
144- cell = rnn_cell .MultiRNNCell ([cell_fn (rnn_size )] * num_layers )
145- _ , encoding = rnn .rnn (cell , X , dtype = tf .float32 ,
146- sequence_length = sequence_length ,
147- initial_state = initial_state )
239+ cell = tf . nn . rnn_cell .MultiRNNCell ([cell_fn (rnn_size )] * num_layers )
240+ _ , encoding = tf . nn .rnn (cell , X , dtype = tf .float32 ,
241+ sequence_length = sequence_length ,
242+ initial_state = initial_state )
148243 return target_predictor_fn (encoding [- 1 ], y )
149244 return rnn_estimator
0 commit comments