
Commit 2f7e854

Merge pull request #1033 from tensorlayer/RNN
RNN updates
2 parents 15d48b5 + 22ca25f commit 2f7e854

File tree

CHANGELOG.md
tensorlayer/layers/recurrent.py
tests/layers/test_layers_recurrent.py

3 files changed: +72 -9 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ To release a new version, please update the changelog as followed:
 ### Deprecated
 
 ### Fixed
+- RNN updates: remove warnings, fix the seq_len=0 case, add unit test (PR #1033)
 
 ### Removed
 

tensorlayer/layers/recurrent.py

Lines changed: 15 additions & 9 deletions
@@ -105,7 +105,10 @@ class RNN(Layer):
     Similar to the DynamicRNN in TL 1.x.
 
     If the `sequence_length` is provided in RNN's forwarding and both `return_last_output` and `return_last_state`
-    are set as `True`, the forward function will automatically ignore the paddings.
+    are set as `True`, the forward function will automatically ignore the paddings. Note that if `return_last_output`
+    is set as `False`, the synced sequence outputs will still include the outputs that correspond to paddings,
+    but users are free to select which slice of the outputs to use in subsequent processing.
+
     The `sequence_length` should be a list of integers which indicates the length of each sequence.
     It is recommended to
     `tl.layers.retrieve_seq_length_op3 <https://tensorlayer.readthedocs.io/en/latest/modules/layers.html#compute-sequence-length-3>`__
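As an aside, a minimal sketch of the documented behaviour, modelled on the unit test added below. The `PadAwareRNN` wrapper is hypothetical; `tl.layers.LSTMRNN` and `tl.layers.retrieve_seq_length_op3` are the names used in this diff.

    import numpy as np
    import tensorlayer as tl

    class PadAwareRNN(tl.models.Model):  # hypothetical wrapper, for illustration only

        def __init__(self):
            super(PadAwareRNN, self).__init__()
            self.rnnlayer = tl.layers.LSTMRNN(
                units=8, in_channels=4, return_last_output=True, return_last_state=True
            )

        def forward(self, x):
            # With sequence_length given and both return_last_* flags True,
            # the layer returns the output/state at each sequence's last valid step.
            return self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x))

    x = np.random.random([2, 5, 4]).astype(np.float32)
    x[0, 3:, :] = 0  # zero-pad the first sequence after its third step
    model = PadAwareRNN()
    model.eval()
    last_output, last_state = model(x)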
@@ -244,16 +247,15 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
                     "but got an actual length of a sequence %d" % i
                 )
 
-            sequence_length = [i - 1 for i in sequence_length]
+            sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length]
 
         # set warning
-        if (not self.return_last_state or not self.return_last_output) and sequence_length is not None:
-            warnings.warn(
-                'return_last_output is set as %s ' % self.return_last_output +
-                'and return_last_state is set as %s. ' % self.return_last_state +
-                'When sequence_length is provided, both are recommended to set as True. ' +
-                'Otherwise, padding will be considered while RNN is forwarding.'
-            )
+        # if (not self.return_last_output) and sequence_length is not None:
+        #     warnings.warn(
+        #         'return_last_output is set as %s ' % self.return_last_output +
+        #         'When sequence_length is provided, it is recommended to set as True. ' +
+        #         'Otherwise, padding will be considered while RNN is forwarding.'
+        #     )
 
         # return the last output, iterating each seq including padding ones. No need to store output during each
         # time step.
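The `else 0` guard matters because these values are used as last-step indices: with a plain `i - 1`, an all-padding sequence of length 0 becomes index -1, which would gather the final (padded) time step instead of a real one. A standalone sketch with illustrative lengths:

    sequence_length = [6, 3, 0]  # per-sequence valid lengths; 0 means all padding

    # Before the fix: the empty sequence maps to index -1, i.e. the last time step.
    old_indices = [i - 1 for i in sequence_length]                   # [5, 2, -1]

    # After the fix: empty sequences are clamped to index 0.
    new_indices = [i - 1 if i >= 1 else 0 for i in sequence_length]  # [5, 2, 0]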
@@ -274,6 +276,7 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
             self.cell.reset_recurrent_dropout_mask()
 
         # recurrent computation
+        # FIXME: if sequence_length is provided (dynamic rnn), only iterate max(sequence_length) times.
         for time_step in range(total_steps):
 
             cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train)
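A hedged sketch of the bound the FIXME proposes, not current behaviour. `dynamic_steps` is a hypothetical helper, written against the converted `i - 1` indices in use at this point in `forward`:

    def dynamic_steps(num_steps, last_step_indices):
        # Number of loop iterations actually needed once each sequence's
        # last valid step index is known; trailing all-padding steps are skipped.
        if last_step_indices is None:
            return num_steps
        return min(num_steps, max(last_step_indices) + 1)

    # Example: 6 unrolled steps, but the longest sequence ends at index 3,
    # so only 4 iterations of the recurrent loop are required.
    assert dynamic_steps(6, [3, 2, 0]) == 4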
@@ -758,6 +761,7 @@ def forward(self, inputs, fw_initial_state=None, bw_initial_state=None, **kwargs
         return outputs
 
 
+'''
 class ConvRNNCell(object):
     """Abstract object representing an Convolutional RNN Cell."""

@@ -1071,6 +1075,8 @@ def __init__(
         self._add_layers(self.outputs)
         self._add_params(rnn_variables)
 
+'''
+
 
 # @tf.function
 def retrieve_seq_length_op(data):

tests/layers/test_layers_recurrent.py

Lines changed: 56 additions & 0 deletions
@@ -26,7 +26,13 @@ def setUpClass(cls):
         cls.hidden_size = 8
         cls.num_steps = 6
 
+        cls.data_n_steps = np.random.randint(low=cls.num_steps // 2, high=cls.num_steps + 1, size=cls.batch_size)
         cls.data_x = np.random.random([cls.batch_size, cls.num_steps, cls.embedding_size]).astype(np.float32)
+
+        for i in range(cls.batch_size):
+            for j in range(cls.data_n_steps[i], cls.num_steps):
+                cls.data_x[i][j][:] = 0
+
         cls.data_y = np.zeros([cls.batch_size, 1]).astype(np.float32)
         cls.data_y2 = np.zeros([cls.batch_size, cls.num_steps]).astype(np.float32)
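The zeroed tails above are what let `tl.layers.retrieve_seq_length_op3` recover `data_n_steps` inside the new test. A small sketch of that round trip; shapes are illustrative, and the helper is assumed to count non-zero time steps as its documentation describes:

    import numpy as np
    import tensorlayer as tl

    batch_size, num_steps, embedding_size = 4, 6, 3
    n_steps = np.random.randint(low=num_steps // 2, high=num_steps + 1, size=batch_size)
    x = np.random.random([batch_size, num_steps, embedding_size]).astype(np.float32)
    for i in range(batch_size):
        x[i, n_steps[i]:, :] = 0  # zero out the padded tail of each sequence

    lengths = tl.layers.retrieve_seq_length_op3(x)  # counts non-zero time steps
    print(lengths)  # expected to match n_steps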

@@ -865,6 +871,56 @@ def forward(self, x):
         print(output.shape)
         print(state)
 
+    def test_dynamic_rnn_with_fake_data(self):
+
+        class CustomisedModel(tl.models.Model):
+
+            def __init__(self):
+                super(CustomisedModel, self).__init__()
+                self.rnnlayer = tl.layers.LSTMRNN(
+                    units=8, dropout=0.1, in_channels=4, return_last_output=True, return_last_state=False
+                )
+                self.dense = tl.layers.Dense(in_channels=8, n_units=1)
+
+            def forward(self, x):
+                z = self.rnnlayer(x, sequence_length=tl.layers.retrieve_seq_length_op3(x))
+                z = self.dense(z[:, :])
+                return z
+
+        rnn_model = CustomisedModel()
+        print(rnn_model)
+        optimizer = tf.optimizers.Adam(learning_rate=0.01)
+        rnn_model.train()
+
+        for epoch in range(50):
+            with tf.GradientTape() as tape:
+                pred_y = rnn_model(self.data_x)
+                loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+
+            gradients = tape.gradient(loss, rnn_model.trainable_weights)
+            optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights))
+
+            if (epoch + 1) % 10 == 0:
+                print("epoch %d, loss %f" % (epoch, loss))
+
+        filename = "dynamic_rnn.h5"
+        rnn_model.save_weights(filename)
+
+        # Testing saving and restoring of RNN weights
+        rnn_model2 = CustomisedModel()
+        rnn_model2.eval()
+        pred_y = rnn_model2(self.data_x)
+        loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+        print("MODEL INIT loss %f" % (loss))
+
+        rnn_model2.load_weights(filename)
+        pred_y = rnn_model2(self.data_x)
+        loss = tl.cost.mean_squared_error(pred_y, self.data_y)
+        print("MODEL RESTORE W loss %f" % (loss))
+
+        import os
+        os.remove(filename)
+
 
 if __name__ == '__main__':
