Skip to content

Commit 2e34583

Browse files
committed
Comment out the return_last_output warning; fix handling when seq_len=0
1 parent 7a28a2b commit 2e34583

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tensorlayer/layers/recurrent.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -244,15 +244,15 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
244244
"but got an actual length of a sequence %d" % i
245245
)
246246

247-
sequence_length = [i - 1 for i in sequence_length]
247+
sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length]
248248

249249
# set warning
250-
if (not self.return_last_output) and sequence_length is not None:
251-
warnings.warn(
252-
'return_last_output is set as %s ' % self.return_last_output +
253-
'When sequence_length is provided, it is recommended to set as True. ' +
254-
'Otherwise, padding will be considered while RNN is forwarding.'
255-
)
250+
# if (not self.return_last_output) and sequence_length is not None:
251+
# warnings.warn(
252+
# 'return_last_output is set as %s ' % self.return_last_output +
253+
# 'When sequence_length is provided, it is recommended to set as True. ' +
254+
# 'Otherwise, padding will be considered while RNN is forwarding.'
255+
# )
256256

257257
# return the last output, iterating each seq including padding ones. No need to store output during each
258258
# time step.
@@ -273,6 +273,7 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
273273
self.cell.reset_recurrent_dropout_mask()
274274

275275
# recurrent computation
276+
# FIXME: if sequence_length is provided (dynamic rnn), only iterate max(sequence_length) times.
276277
for time_step in range(total_steps):
277278

278279
cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train)

0 commit comments

Comments (0)