Skip to content
This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit 6ebacbc

Browse files
committed
Added small test for bidirectional rnn and forked code from current github version of TF to make it work. Will cleanup when 0.7 is out
1 parent bb239fa commit 6ebacbc

File tree

3 files changed

+136
-20
lines changed

3 files changed

+136
-20
lines changed

skflow/estimators/rnn.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class TensorFlowRNNClassifier(TensorFlowEstimator, ClassifierMixin):
3535
input_op_fn: Function that will transform the input tensor, such as
3636
creating word embeddings, byte list, etc. This takes
3737
an argument X for input and returns transformed X.
38-
bidirection: Whether this is a bidirectional rnn.
38+
bidirectional: boolean, Whether this is a bidirectional rnn.
3939
sequence_length: If sequence_length is provided, dynamic calculation is performed.
4040
This saves computational time when unrolling past max sequence length.
4141
initial_state: An initial state for the RNN. This must be a tensor of appropriate type
@@ -71,7 +71,7 @@ def exp_decay(global_step):
7171

7272
def __init__(self, rnn_size, n_classes, cell_type='gru', num_layers=1,
7373
input_op_fn=null_input_op_fn,
74-
initial_state=None, bidirection=False,
74+
initial_state=None, bidirectional=False,
7575
sequence_length=None, tf_master="", batch_size=32,
7676
steps=50, optimizer="SGD", learning_rate=0.1,
7777
tf_random_seed=42, continue_training=False,
@@ -80,7 +80,7 @@ def __init__(self, rnn_size, n_classes, cell_type='gru', num_layers=1,
8080
self.rnn_size = rnn_size
8181
self.cell_type = cell_type
8282
self.input_op_fn = input_op_fn
83-
self.bidirection = bidirection
83+
self.bidirectional = bidirectional
8484
self.num_layers = num_layers
8585
self.sequence_length = sequence_length
8686
self.initial_state = initial_state
@@ -97,7 +97,7 @@ def __init__(self, rnn_size, n_classes, cell_type='gru', num_layers=1,
9797
def _model_fn(self, X, y):
9898
return models.get_rnn_model(self.rnn_size, self.cell_type,
9999
self.num_layers,
100-
self.input_op_fn, self.bidirection,
100+
self.input_op_fn, self.bidirectional,
101101
models.logistic_regression,
102102
self.sequence_length,
103103
self.initial_state)(X, y)
@@ -123,7 +123,7 @@ class TensorFlowRNNRegressor(TensorFlowEstimator, RegressorMixin):
123123
input_op_fn: Function that will transform the input tensor, such as
124124
creating word embeddings, byte list, etc. This takes
125125
an argument X for input and returns transformed X.
126-
bidirection: Whether this is a bidirectional rnn.
126+
bidirectional: boolean, Whether this is a bidirectional rnn.
127127
sequence_length: If sequence_length is provided, dynamic calculation is performed.
128128
This saves computational time when unrolling past max sequence length.
129129
initial_state: An initial state for the RNN. This must be a tensor of appropriate type
@@ -152,7 +152,7 @@ def exp_decay(global_step):
152152

153153
def __init__(self, rnn_size, cell_type='gru', num_layers=1,
154154
input_op_fn=null_input_op_fn, initial_state=None,
155-
bidirection=False, sequence_length=None,
155+
bidirectional=False, sequence_length=None,
156156
n_classes=0, tf_master="", batch_size=32,
157157
steps=50, optimizer="SGD", learning_rate=0.1,
158158
tf_random_seed=42, continue_training=False,
@@ -161,7 +161,7 @@ def __init__(self, rnn_size, cell_type='gru', num_layers=1,
161161
self.rnn_size = rnn_size
162162
self.cell_type = cell_type
163163
self.input_op_fn = input_op_fn
164-
self.bidirection = bidirection
164+
self.bidirectional = bidirectional
165165
self.num_layers = num_layers
166166
self.sequence_length = sequence_length
167167
self.initial_state = initial_state
@@ -178,7 +178,7 @@ def __init__(self, rnn_size, cell_type='gru', num_layers=1,
178178
def _model_fn(self, X, y):
179179
return models.get_rnn_model(self.rnn_size, self.cell_type,
180180
self.num_layers,
181-
self.input_op_fn, self.bidirection,
181+
self.input_op_fn, self.bidirectional,
182182
models.linear_regression,
183183
self.sequence_length,
184184
self.initial_state)(X, y)

skflow/models.py

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,102 @@ def dnn_estimator(X, y):
9292
return target_predictor_fn(layers, y)
9393
return dnn_estimator
9494

95+
## This will be in Tensorflow 0.7.
96+
## TODO(ilblackdragon): Clean this up when it's released
97+
98+
99+
def _reverse_seq(input_seq, lengths):
100+
"""Reverse a list of Tensors up to specified lengths.
101+
Args:
102+
input_seq: Sequence of seq_len tensors of dimension (batch_size, depth)
103+
lengths: A tensor of dimension batch_size, containing lengths for each
104+
sequence in the batch. If "None" is specified, simply reverses
105+
the list.
106+
Returns:
107+
time-reversed sequence
108+
"""
109+
if lengths is None:
110+
return list(reversed(input_seq))
111+
112+
for input_ in input_seq:
113+
input_.set_shape(input_.get_shape().with_rank(2))
114+
115+
# Join into (time, batch_size, depth)
116+
s_joined = tf.pack(input_seq)
117+
118+
# Reverse along dimension 0
119+
s_reversed = tf.reverse_sequence(s_joined, lengths, 0, 1)
120+
# Split again into list
121+
result = tf.unpack(s_reversed)
122+
return result
123+
124+
125+
def bidirectional_rnn(cell_fw, cell_bw, inputs,
126+
initial_state_fw=None, initial_state_bw=None,
127+
dtype=None, sequence_length=None, scope=None):
128+
"""Creates a bidirectional recurrent neural network.
129+
Similar to the unidirectional case above (rnn) but takes input and builds
130+
independent forward and backward RNNs with the final forward and backward
131+
outputs depth-concatenated, such that the output will have the format
132+
[time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of
133+
forward and backward cell must match. The initial state for both directions
134+
is zero by default (but can be set optionally) and no intermediate states are
135+
ever returned -- the network is fully unrolled for the given (passed in)
136+
length(s) of the sequence(s) or completely unrolled if length(s) is not given.
137+
Args:
138+
cell_fw: An instance of RNNCell, to be used for forward direction.
139+
cell_bw: An instance of RNNCell, to be used for backward direction.
140+
inputs: A length T list of inputs, each a tensor of shape
141+
[batch_size, cell.input_size].
142+
initial_state_fw: (optional) An initial state for the forward RNN.
143+
This must be a tensor of appropriate type and shape
144+
[batch_size x cell.state_size].
145+
initial_state_bw: (optional) Same as for initial_state_fw.
146+
dtype: (optional) The data type for the initial state. Required if either
147+
of the initial states are not provided.
148+
sequence_length: (optional) An int64 vector (tensor) of size [batch_size],
149+
containing the actual lengths for each of the sequences.
150+
scope: VariableScope for the created subgraph; defaults to "BiRNN"
151+
Returns:
152+
A set of output `Tensors` where:
153+
outputs is a length T list of outputs (one for each input), which
154+
are depth-concatenated forward and backward outputs
155+
Raises:
156+
TypeError: If "cell_fw" or "cell_bw" is not an instance of RNNCell.
157+
ValueError: If inputs is None or an empty list.
158+
"""
159+
160+
if not isinstance(cell_fw, tf.nn.rnn_cell.RNNCell):
161+
raise TypeError("cell_fw must be an instance of RNNCell")
162+
if not isinstance(cell_bw, tf.nn.rnn_cell.RNNCell):
163+
raise TypeError("cell_bw must be an instance of RNNCell")
164+
if not isinstance(inputs, list):
165+
raise TypeError("inputs must be a list")
166+
if not inputs:
167+
raise ValueError("inputs must not be empty")
168+
169+
name = scope or "BiRNN"
170+
# Forward direction
171+
with tf.variable_scope(name + "_FW"):
172+
output_fw, _ = tf.nn.rnn(cell_fw, inputs, initial_state_fw, dtype,
173+
sequence_length)
174+
175+
# Backward direction
176+
with tf.variable_scope(name + "_BW"):
177+
tmp, _ = tf.nn.rnn(cell_bw, _reverse_seq(inputs, sequence_length),
178+
initial_state_bw, dtype, sequence_length)
179+
output_bw = _reverse_seq(tmp, sequence_length)
180+
# Concat each of the forward/backward outputs
181+
outputs = [tf.concat(1, [fw, bw])
182+
for fw, bw in zip(output_fw, output_bw)]
183+
184+
return outputs
185+
186+
# End of Tensorflow 0.7
187+
95188

96189
def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn,
97-
bidirection, target_predictor_fn,
190+
bidirectional, target_predictor_fn,
98191
sequence_length, initial_state):
99192
"""Returns a function that creates a RNN TensorFlow subgraph with given
100193
params.
@@ -106,13 +199,14 @@ def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn,
106199
input_op_fn: Function that will transform the input tensor, such as
107200
creating word embeddings, byte list, etc. This takes
108201
an argument X for input and returns transformed X.
109-
bidirection: Whether this is a bidirectional rnn.
202+
bidirectional: boolean, Whether this is a bidirectional rnn.
110203
target_predictor_fn: Function that will predict target from input
111204
features. This can be logistic regression,
112205
linear regression or any other model,
113206
that takes X, y and returns predictions and loss tensors.
114207
sequence_length: If sequence_length is provided, dynamic calculation is performed.
115208
This saves computational time when unrolling past max sequence length.
209+
Required for bidirectional RNNs.
116210
initial_state: An initial state for the RNN. This must be a tensor of appropriate type
117211
and shape [batch_size x cell.state_size].
118212
@@ -123,26 +217,28 @@ def rnn_estimator(X, y):
123217
"""RNN estimator with target predictor function on top."""
124218
X = input_op_fn(X)
125219
if cell_type == 'rnn':
126-
cell_fn = rnn_cell.BasicRNNCell
220+
cell_fn = tf.nn.rnn_cell.BasicRNNCell
127221
elif cell_type == 'gru':
128-
cell_fn = rnn_cell.GRUCell
222+
cell_fn = tf.nn.rnn_cell.GRUCell
129223
elif cell_type == 'lstm':
130-
cell_fn = rnn_cell.BasicLSTMCell
224+
cell_fn = tf.nn.rnn_cell.BasicLSTMCell
131225
else:
132226
raise ValueError("cell_type {} is not supported. ".format(cell_type))
133-
if bidirection:
227+
if bidirectional:
134228
# forward direction cell
135229
rnn_fw_cell = tf.nn.rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
136230
# backward direction cell
137231
rnn_bw_cell = tf.nn.rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
138232
# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
139-
encoding = tf.nn.rnn.bidirectional_rnn(rnn_fw_cell, rnn_bw_cell,
140-
sequence_length=sequence_length,
141-
initial_state=initial_state)
233+
encoding = bidirectional_rnn(rnn_fw_cell, rnn_bw_cell, X,
234+
dtype=tf.float32,
235+
sequence_length=sequence_length,
236+
initial_state_fw=initial_state,
237+
initial_state_bw=initial_state)
142238
else:
143239
cell = tf.nn.rnn_cell.MultiRNNCell([cell_fn(rnn_size)] * num_layers)
144-
_, encoding = tf.nn.rnn.rnn(cell, X, dtype=tf.float32,
145-
sequence_length=sequence_length,
146-
initial_state=initial_state)
240+
_, encoding = tf.nn.rnn(cell, X, dtype=tf.float32,
241+
sequence_length=sequence_length,
242+
initial_state=initial_state)
147243
return target_predictor_fn(encoding[-1], y)
148244
return rnn_estimator

skflow/tests/test_nonlinear.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,26 @@ def input_fn(X):
9090
predictions = classifier.predict(np.array(list([[1, 3, 3, 2, 1],
9191
[2, 3, 4, 5, 6]])))
9292

93+
def testBidirectionalRNN(self):
94+
random.seed(42)
95+
import numpy as np
96+
data = np.array(list([[2, 1, 2, 2, 3],
97+
[2, 2, 3, 4, 5],
98+
[3, 3, 1, 2, 1],
99+
[2, 4, 5, 4, 1]]), dtype=np.float32)
100+
labels = np.array(list([1, 0, 1, 0]), dtype=np.float32)
101+
def input_fn(X):
102+
return tf.split(1, 5, X)
103+
104+
# Classification
105+
classifier = skflow.TensorFlowRNNClassifier(
106+
rnn_size=2, cell_type='lstm', n_classes=2, input_op_fn=input_fn,
107+
bidirectional=True)
108+
classifier.fit(data, labels)
109+
predictions = classifier.predict(np.array(list([[1, 3, 3, 2, 1],
110+
[2, 3, 4, 5, 6]])))
111+
self.assertAllClose(predictions, np.array([1, 0]))
112+
93113

94114
if __name__ == "__main__":
95115
tf.test.main()

0 commit comments

Comments
 (0)