Skip to content

Commit 6c3b732

Browse files
authored
Update recurrent.py
move the modification below sequence length check
1 parent 939ecd2 commit 6c3b732

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tensorlayer/layers/recurrent.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,6 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
215215
batch_size = inputs.get_shape().as_list()[0]
216216
total_steps = inputs.get_shape().as_list()[1]
217217

218-
# Since sequence_length is not passed into computational graph when build a static model, we force sequence_length to be not None to get dynamic RNN.
219-
# We checked that sequence_length is not passed to the model whatever it is, which induces a lower accuracy for training and validation
220-
sequence_length = tl.layers.retrieve_seq_length_op3(inputs)
221-
222218
# checking the type and values of sequence_length
223219
if sequence_length is not None:
224220
if isinstance(sequence_length, list):
@@ -251,7 +247,13 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
251247
"but got an actual length of a sequence %d" % i
252248
)
253249

254-
sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length]
250+
'''
251+
Since sequence_length is not passed into computational graph when build a static model, we force sequence_length to be not None to get dynamic RNN.
252+
We test this code that sequence_length is not passed to the model whatever it is, which induce a lower accuracy for training and validation
253+
'''
254+
sequence_length = tl.layers.retrieve_seq_length_op3(inputs)
255+
256+
sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length]
255257

256258
# set warning
257259
# if (not self.return_last_output) and sequence_length is not None:

0 commit comments

Comments
 (0)