@@ -244,15 +244,15 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
                        "but got an actual length of a sequence %d" % i
                    )

-            sequence_length = [i - 1 for i in sequence_length]
+            sequence_length = [i - 1 if i >= 1 else 0 for i in sequence_length]

         # set warning
-        if (not self.return_last_output) and sequence_length is not None:
-            warnings.warn(
-                'return_last_output is set as %s ' % self.return_last_output +
-                'When sequence_length is provided, it is recommended to set as True. ' +
-                'Otherwise, padding will be considered while RNN is forwarding.'
-            )
+        # if (not self.return_last_output) and sequence_length is not None:
+        #     warnings.warn(
+        #         'return_last_output is set as %s ' % self.return_last_output +
+        #         'When sequence_length is provided, it is recommended to set as True. ' +
+        #         'Otherwise, padding will be considered while RNN is forwarding.'
+        #     )

         # return the last output, iterating each seq including padding ones. No need to store output during each
         # time step.
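
The `i - 1 if i >= 1 else 0` change above converts each entry of `sequence_length` from a length into the index of the last valid time step, clamping so that a value below 1 no longer produces -1 (which Python indexing would interpret as the final padded step). A minimal standalone illustration of the expression with made-up values, not code from this PR:

    # Lengths -> indices of the last valid time step; a length of 0
    # now maps to index 0 instead of -1.
    sequence_length = [5, 1, 0]
    last_valid = [i - 1 if i >= 1 else 0 for i in sequence_length]
    assert last_valid == [4, 0, 0]  # the old expression would give [4, 0, -1]
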
@@ -273,6 +273,7 @@ def forward(self, inputs, sequence_length=None, initial_state=None, **kwargs):
         self.cell.reset_recurrent_dropout_mask()

         # recurrent computation
+        # FIXME: if sequence_length is provided (dynamic rnn), only iterate max(sequence_length) times.
         for time_step in range(total_steps):

             cell_output, states = self.cell.call(inputs[:, time_step, :], states, training=self.is_train)
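
The FIXME added in this hunk proposes iterating only up to the longest real sequence rather than the padded length. One way that could look, as a standalone sketch under the assumption that `sequence_length` holds last-valid indices after the `i - 1` adjustment in the first hunk; the helper name `dynamic_steps` is hypothetical, not part of this PR:

    def dynamic_steps(total_steps, sequence_length):
        # Hypothetical helper: number of loop iterations actually needed.
        # sequence_length holds per-sample indices of the last valid step
        # (the values produced by `i - 1 if i >= 1 else 0` above).
        if sequence_length is None:
            return total_steps
        # +1 converts the largest last-valid index back into a step count;
        # never exceed the padded length.
        return min(total_steps, max(sequence_length) + 1)

    # A batch padded to 10 steps whose longest real sequence has 7 steps
    # (last valid index 6) needs only 7 iterations.
    assert dynamic_steps(10, [3, 6, 1]) == 7
    assert dynamic_steps(10, None) == 10
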