Skip to content

Commit bf8ff10

Browse files
committed
yapf format and solve travis-ci problem
1 parent ccf27f6 commit bf8ff10

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

tensorlayer/layers/recurrent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ class RNN(Layer):
105105
Similar to the DynamicRNN in TL 1.x.
106106
107107
If the `sequence_length` is provided in RNN's forwarding and both `return_last_output` and `return_last_state`
108-
are set as `True`, the forward function will automatically ignore the paddings.
108+
are set as `True`, the forward function will automatically ignore the paddings. Note that if `return_last_output`
109+
is set as `False`, the synced sequence outputs will still include outputs which correspond with paddings,
110+
but users are free to select which slice of outputs to be used in following procedure.
111+
109112
The `sequence_length` should be a list of integers which indicates the length of each sequence.
110113
It is recommended to
111114
`tl.layers.retrieve_seq_length_op3 <https://tensorlayer.readthedocs.io/en/latest/modules/layers.html#compute-sequence-length-3>`__
@@ -1074,6 +1077,7 @@ def __init__(
10741077
10751078
'''
10761079

1080+
10771081
# @tf.function
10781082
def retrieve_seq_length_op(data):
10791083
"""An op to compute the length of a sequence from input shape of [batch_size, n_step(max), n_features],

tests/layers/test_layers_recurrent.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Layer_RNN_Test(CustomTestCase):
1818
@classmethod
1919
def setUpClass(cls):
2020

21-
cls.batch_size = 10
21+
cls.batch_size = 2
2222

2323
cls.vocab_size = 20
2424
cls.embedding_size = 4
@@ -878,9 +878,7 @@ class CustomisedModel(tl.models.Model):
878878
def __init__(self):
879879
super(CustomisedModel, self).__init__()
880880
self.rnnlayer = tl.layers.LSTMRNN(
881-
units=8, dropout=0.1, in_channels=4,
882-
return_last_output=True,
883-
return_last_state=False
881+
units=8, dropout=0.1, in_channels=4, return_last_output=True, return_last_state=False
884882
)
885883
self.dense = tl.layers.Dense(in_channels=8, n_units=1)
886884

@@ -924,7 +922,6 @@ def forward(self, x):
924922
os.remove(filename)
925923

926924

927-
928925
if __name__ == '__main__':
929926

930927
unittest.main()

0 commit comments

Comments
 (0)